The owl of Minerva spreads its wings only with the falling of the dusk

Markdown与其支持的的LaTeX
Markdown
- 圆点突出:- +文字
- 序号列表:数字+.
- 链接:[文本](链接)
- 图片:
- 斜体:*文字*
- 加粗:**文字**
- 粗斜体:***文字***
- 分割线:*********或者----------
采用>+文字
几个#就代表几级标题
反引号
import numpy as np
利用三个反引号,后面加上代码使用的语言,如python可实现高亮
表格
| 表头 | 表头 |
| ---- | ---- |
| 单元格 | 单元格 |
| 单元格 | 单元格 |
| 表头 | 表头 |
|---|---|
| 单元格 | 单元格 |
| 单元格 | 单元格 |
还可设置表格的对齐方式
- :- 实现左对齐
- -: 实现右对齐
- :-: 实现居中
| 左对齐 | 右对齐 | 居中对齐 |
| :-----| ----: | :----: |
| 1 | 1 | 1 |
| 左对齐 | 右对齐 | 居中对齐 |
|---|---|---|
| 1 | 1 | 1 |
markdown中的LaTeX
支持的写法详情见katex官方文档
多行公式
为公式加编号:\tag{number}
在VSCODE中必须使用aligned环境,如:
$$
\begin{aligned}
a&=1\\
&=2
\end{aligned}
$$
分段函数
使用cases环境
$$
f(x) =
\begin{cases}
a, a<1,\\
b, a\geq1.
\end{cases}
$$
效果如下
文内链接
[能够点击的链接](#name)
<div id="name"></div>
python
numpy
import numpy as np
基础
- 数组:np.array([[1,2,3],[4,5,6],...])
- 序列:np.arange(a,b,c) 初值a,终值b,步长c,不含终值
- 序列:np.linspace(a,b,c) 初值a,终值b,个数c,含终值
- 空数组:np.empty((a,b),np.int) shape&type
- 零数组:np.zeros((a,b),np.int)
- 单位阵:np.eye(N,M=None,k=0) 行数N,列数M,k为对角线上移(正)或下移(负)
- 全1阵:np.ones((a,b),np.int)
- 转置与内积:np.dot(A.T,A)
切片与索引
切片是视图而非副本,若要副本:arr[5:8].copy()
布尔型索引:data[data<0]=0, ~可用来反转条件
通用函数func
一元函数:np. f (arr)
- abs,fabs,sqrt开根,square平方
- exp,log,log10,log2
- sign,ceil向上取整,floor向下取整,rint四舍五入,modf拆成整数和小数
- isnan,isfinite,isinf
- cos,sin,cosh,sinh,tan,tanh
二元函数:np. f (arr)
- add,substract,multiply,divide,floor_divide 除后取整
- power,maximum,fmax,mod
- copysign 得到第二个数组的符号
- greater,greater_equal,less,less_equal,equal,not_equal返回布尔值
- meshgrid 接受两个一维数组,产生两个二维数组,对应所有(x,y)对
np.where()是 x if condition else y的矢量化版本 np.where(arr>2,2,-2) np.where(arr>2,2,arr)
数组统计方法
- sum,mean,std,var,min,max
- argmax,argmin 索引,cumsum,cumprod
- 查询数组中是否有true:all,any(示例:
(a=b).all())
排序
sort
集合运算,数字1
- unique(x)
- intersect1d(x,y) 交集
- union1d(x,y) 并集
- in1d(x,y) 包含于
- setdiff1d(x,y) 差集
- setxor1d(x,y) 对称差
常用numpy.linalg函数(npl)
- diag 对角阵和一维数组转化
- dot,trace,det
- eig 特征值特征向量
- inv 逆
- pinv Moore-Penrose 伪逆
- qr QR分解
- svd 奇异值分解
- solve 解Ax=b ,A方针
- lstsq Ax=b最小二乘解
部分numpy.random函数
- seed 确定随机数生成器种子
- permutation 返回新的打乱的x,x不变
- shuffle 原地打乱x
- rand 均匀分布
- randint 给定范围内随机取整数
- randn 标准正态分布
- binomial 二项分布
- normal 正态分布
- beta,gamma
- chisquare 卡方分布
- uniform [0,1)均匀分布
数组合并与拼接
- append(arr,values,axis=None)
Pandas
Series([1, 2, 3, 4], index=[a, b, c, d]) 查缺失数据 .isnull(), .notnull(), 返回同结构的布尔值 Series对象本身及其索引有个 name 属性
a = pd.Series([1 ,2 ,3 ,4], index=['a', 'b', 'c', 'd'])
a.name = 'series'
将序列作为 DataFrame 的一列时,name属性就变为那一列的列名
DataFrame
data = np.ones([3,4])
d = pd.DataFrame(data, index=['a','b','c'], columns=['a','b','c','d'])
.head() 取前五行,.tail() .del() 删除某一列,.drop()删除指定轴上某些项 .append() .difference() .intersection() .union()
索引 用标签名 data.loc['a',['c','d']] 不用标签名 data.iloc[2,[2,3]]
常用方法 .cumsum() .cumprod() .diff() .pct_change()
换指定列名 d=d.rename( index={1:'new'}, columns={'a':'shit'} )
matplotlib
import matplotlib.pyplot as plt
data1 = np.linspace(1,200,2000)
data2 = np.random.randn(2000)
fig = plt.figure()
ax1 = fig.add_subplot(2,2,1)
plt.plot(data1, label='first')
plt.plot(data2,'.', label='second')
ax1.set_xticks([1,2,40])
ax1.legend(loc='best')
ax1.set_title('first plot')
ax1.set_xlabel('index')
ax2 = fig.add_subplot(2,2,2)
'-'实线 '--'短划线 '-.'点划线 ':'虚线 '.'点 'v'倒三角 等 color参数 'b'蓝 'g'绿 'r'红 'c'青 'm'品红 'y'黄 'k'黑 'w'白
python爬虫_BeautifulSoup
获取浏览器模拟头部信息
- 浏览器输入网址
- F12 网络(network)
- 随便点一个,最下方请求标头‘User-Agent’部分
- 复制到脚本中head = {‘User-Agent’:xxxx}
# 查看头部
python正则表达式
正则表达式基础
语法
-
普通字符:
-
-
普通字符包括没有显式指定为元字符的所有可打印和不可打印字符。这包括所有大写和小写字母、所有数字、所有标点符号和一些其他符号。
-
[ABC]:匹配 [...] 中的所有字符,例如 [aeiou] 匹配字符串 "google runoob taobao" 中所有的 e o u a 字母
-
1:匹配除了 [...] 中字符的所有字符,例如 2 匹配字符串 "google runoob taobao" 中除了 e o u a 字母的所有字母
-
[A-Z]:[A-Z] 表示一个区间,匹配所有大写字母,[a-z] 表示所有小写字母
-
. :匹配除换行符**(\n、\r)之外的任何单个字符,相等于[ ^\n\r]**
-
[\s\S]:匹配所有。\s 是匹配所有空白符,包括换行,\S 非空白符,不包括换行
-
\w:匹配字母、数字、下划线。等价于 [A-Za-z0-9_]
-
-
非打印字符:
-
-
非打印字符也可以是正则表达式的组成部分。下表列出了表示非打印字符的转义序列
-
\cx:匹配由x指明的控制字符。例如, \cM 匹配一个 Control-M 或回车符。x 的值必须为 A-Z 或 a-z 之一。否则,将 c 视为一个原义的 'c' 字符
-
\f:匹配一个换页符。等价于 \x0c 和 \cL
-
\n:匹配一个换行符。等价于 \x0a 和 \cJ
-
\r:匹配一个回车符。等价于 \x0d 和 \cM
-
\s:匹配任何空白字符,包括空格、制表符、换页符等等。等价于 [ \f\n\r\t\v]。注意 Unicode 正则表达式会匹配全角空格符
-
\S:匹配任何非空白字符。等价于 [ ^ \f\n\r\t\v]
-
\t:匹配一个制表符。等价于 \x09 和 \cI
-
\v:匹配一个垂直制表符。等价于 \x0b 和 \cK
-
-
特殊字符:
-
-
$:匹配输入字符串的结尾位置。要匹配 **** 字符本身,请使用 \$
-
( ):标记一个子表达式的开始和结束位置。子表达式可以获取供以后使用。要匹配这些字符,请使用 ( 和 )
-
*****:匹配前面的子表达式零次或多次。要匹配 * 字符,请使用 *
-
+:匹配前面的子表达式一次或多次。要匹配 + 字符,请使用 +
-
. :匹配除换行符 \n 之外的任何单字符。要匹配 . ,请使用 .
-
[:标记一个中括号表达式的开始。要匹配 [,请使用 [
-
?:匹配前面的子表达式零次或一次,或指明一个非贪婪限定符。要匹配 ? 字符,请使用 ?
-
\:将下一个字符标记为或特殊字符、或原义字符、或向后引用、或八进制转义符。例如, 'n' 匹配字符 'n'。'\n' 匹配换行符。序列 '\' 匹配 "",而 '(' 则匹配 "("
-
^:匹配输入字符串的开始位置,除非在方括号表达式中使用,当该符号在方括号表达式中使用时,表示不接受该方括号表达式中的字符集合。要匹配 ^ 字符本身,请使用 ^
-
{:标记限定符表达式的开始。要匹配 {,请使用 {
-
|:指明两项之间的一个选择。要匹配 |,请使用 |
-
-
限定符:
-
-
限定符用来指定正则表达式的一个给定组件必须要出现多少次才能满足匹配
-
:匹配前面的子表达式零次或多次。例如,zo* 能匹配 "z" 以及 "zoo"。**** 等价于 {0,}
-
+:匹配前面的子表达式一次或多次。例如,zo+ 能匹配 "zo" 以及 "zoo",但不能匹配 "z"。+ 等价于 {1,}
-
?:匹配前面的子表达式零次或一次。例如,do(es)? 可以匹配 "do" 、 "does"、 "doxy" 中的 "do" 。? 等价于 {0,1}
-
{n}:n 是一个非负整数。匹配确定的 n 次。例如,o{2} 不能匹配 "Bob" 中的 o,但是能匹配 "food" 中的两个 o
-
{n,}:n 是一个非负整数。至少匹配n 次。例如,o{2,} 不能匹配 "Bob" 中的 o,但能匹配 "foooood" 中的所有 o。o{1,} 等价于 o+。o{0,} 则等价于 o*
-
{n,m}:m 和 n 均为非负整数,其中 n <= m。最少匹配 n 次且最多匹配 m 次。例如,o{1,3} 将匹配 "fooooood" 中的前三个 o。o{0,1} 等价于 o?。请注意在逗号和两个数之间不能有空格
-
-
定位符:
-
- 定位符能够将正则表达式固定到行首或行尾。它们还能够创建这样的正则表达式,这些正则表达式出现在一个单词内、在一个单词的开头或者一个单词的结尾。
- ^:匹配输入字符串开始的位置
- ****:匹配输入字符串结尾的位置
- \b:匹配一个单词边界,即字与空格间的位置
- \B:非单词边界匹配
- 注意:不能将限定符与定位符一起使用。由于在紧靠换行或者单词边界的前面或后面不能有一个以上位置,因此不允许诸如 ^* 之类的表达式
-
选择:
-
- 用圆括号 () 将所有选择项括起来,相邻的选择项之间用 | 分隔
修饰符/标记
-
标记也称为修饰符,正则表达式的标记用于指定额外的匹配策略。
标记不写在正则表达式里,标记位于表达式之外,格式如下
/pattern/flags -
i:ignore - 不区分大小写,
-
g:global - 全局匹配
-
m:multiline - 多行匹配
-
s:特殊字符圆点 . 中包含换行符 \n
元字符(重要)
下表包含了元字符的完整列表以及它们在正则表达式上下文中的行为
| 字符 | 描述 |
|---|---|
| \ | 将下一个字符标记为一个特殊字符、或一个原义字符、或一个 向后引用、或一个八进制转义符。例如,'n' 匹配字符 "n"。'\n' 匹配一个换行符。序列 '\' 匹配 "" 而 "(" 则匹配 "("。 |
| ^ | 匹配输入字符串的开始位置。如果设置了 RegExp 对象的 Multiline 属性,^ 也匹配 '\n' 或 '\r' 之后的位置。 |
| ** 也匹配 '\n' 或 '\r' 之前的位置。 | |
| ***** | 匹配前面的子表达式零次或多次。例如,zo* 能匹配 "z" 以及 "zoo"。* 等价于{0,}。 |
| + | 匹配前面的子表达式一次或多次。例如,'zo+' 能匹配 "zo" 以及 "zoo",但不能匹配 "z"。+ 等价于 {1,}。 |
| ? | 匹配前面的子表达式零次或一次。例如,"do(es)?" 可以匹配 "do" 或 "does" 。? 等价于 {0,1}。 |
| {n} | n 是一个非负整数。匹配确定的 n 次。例如,'o{2}' 不能匹配 "Bob" 中的 'o',但是能匹配 "food" 中的两个 o。 |
| {n,} | n 是一个非负整数。至少匹配n 次。例如,'o{2,}' 不能匹配 "Bob" 中的 'o',但能匹配 "foooood" 中的所有 o。'o{1,}' 等价于 'o+'。'o{0,}' 则等价于 'o*'。 |
| {n,m} | m 和 n 均为非负整数,其中n <= m。最少匹配 n 次且最多匹配 m 次。例如,"o{1,3}" 将匹配 "fooooood" 中的前三个 o。'o{0,1}' 等价于 'o?'。请注意在逗号和两个数之间不能有空格。 |
| ? | 当该字符紧跟在任何一个其他限制符 (*, +, ?, {n}, {n,}, {n,m}) 后面时,匹配模式是非贪婪的。非贪婪模式尽可能少的匹配所搜索的字符串,而默认的贪婪模式则尽可能多的匹配所搜索的字符串。例如,对于字符串 "oooo",'o+?' 将匹配单个 "o",而 'o+' 将匹配所有 'o'。 |
| . | 匹配除换行符(\n、\r)之外的任何单个字符。要匹配包括 '\n' 在内的任何字符,请使用像"(.|\n)"的模式。 |
| (pattern) | 匹配 pattern 并获取这一匹配。所获取的匹配可以从产生的 Matches 集合得到,在VBScript 中使用 SubMatches 集合,在JScript 中则使用 9 属性。要匹配圆括号字符,请使用 '(' 或 ')'。 |
| (?:pattern) | 匹配 pattern 但不获取匹配结果,也就是说这是一个非获取匹配,不进行存储供以后使用。这在使用 "或" 字符 (|) 来组合一个模式的各个部分是很有用。例如, 'industr(?:y|ies) 就是一个比 'industry|industries' 更简略的表达式。 |
| (?=pattern) | 正向肯定预查(look ahead positive assert),在任何匹配pattern的字符串开始处匹配查找字符串。这是一个非获取匹配,也就是说,该匹配不需要获取供以后使用。例如,"Windows(?=95|98|NT|2000)"能匹配"Windows2000"中的"Windows",但不能匹配"Windows3.1"中的"Windows"。预查不消耗字符,也就是说,在一个匹配发生后,在最后一次匹配之后立即开始下一次匹配的搜索,而不是从包含预查的字符之后开始。 |
| (?!pattern) | 正向否定预查(negative assert),在任何不匹配pattern的字符串开始处匹配查找字符串。这是一个非获取匹配,也就是说,该匹配不需要获取供以后使用。例如"Windows(?!95|98|NT|2000)"能匹配"Windows3.1"中的"Windows",但不能匹配"Windows2000"中的"Windows"。预查不消耗字符,也就是说,在一个匹配发生后,在最后一次匹配之后立即开始下一次匹配的搜索,而不是从包含预查的字符之后开始。 |
| (?<=pattern) | 反向(look behind)肯定预查,与正向肯定预查类似,只是方向相反。例如,"(?<=95|98|NT|2000)Windows"能匹配"2000Windows"中的"Windows",但不能匹配"3.1Windows"中的"Windows"。 |
| (?<!pattern) | 反向否定预查,与正向否定预查类似,只是方向相反。例如"(?<!95|98|NT|2000)Windows"能匹配"3.1Windows"中的"Windows",但不能匹配"2000Windows"中的"Windows"。 |
| x|y | 匹配 x 或 y。例如,'z|food' 能匹配 "z" 或 "food"。'(z|f)ood' 则匹配 "zood" 或 "food"。 |
| [xyz] | 字符集合。匹配所包含的任意一个字符。例如, '[abc]' 可以匹配 "plain" 中的 'a'。 |
| 3 | 负值字符集合。匹配未包含的任意字符。例如, '[ ^abc]' 可以匹配 "plain" 中的'p'、'l'、'i'、'n'。 |
| [a-z] | 字符范围。匹配指定范围内的任意字符。例如,'[a-z]' 可以匹配 'a' 到 'z' 范围内的任意小写字母字符。 |
| 4 | 负值字符范围。匹配任何不在指定范围内的任意字符。例如,'[ ^a-z]' 可以匹配任何不在 'a' 到 'z' 范围内的任意字符。 |
| \b | 匹配一个单词边界,也就是指单词和空格间的位置。例如, 'er\b' 可以匹配"never" 中的 'er',但不能匹配 "verb" 中的 'er'。 |
| \B | 匹配非单词边界。'er\B' 能匹配 "verb" 中的 'er',但不能匹配 "never" 中的 'er'。 |
| \cx | 匹配由 x 指明的控制字符。例如, \cM 匹配一个 Control-M 或回车符。x 的值必须为 A-Z 或 a-z 之一。否则,将 c 视为一个原义的 'c' 字符。 |
| \d | 匹配一个数字字符。等价于 [0-9]。 |
| \D | 匹配一个非数字字符。等价于 [ ^0-9]。 |
| \f | 匹配一个换页符。等价于 \x0c 和 \cL。 |
| \n | 匹配一个换行符。等价于 \x0a 和 \cJ。 |
| \r | 匹配一个回车符。等价于 \x0d 和 \cM。 |
| \s | 匹配任何空白字符,包括空格、制表符、换页符等等。等价于 [ \f\n\r\t\v]。 |
| \S | 匹配任何非空白字符。等价于 [ ^ \f\n\r\t\v]。 |
| \t | 匹配一个制表符。等价于 \x09 和 \cI。 |
| \v | 匹配一个垂直制表符。等价于 \x0b 和 \cK。 |
| \w | 匹配字母、数字、下划线。等价于'[A-Za-z0-9_]'。 |
| \W | 匹配非字母、数字、下划线。等价于 '[ ^A-Za-z0-9_]'。 |
| \xn | 匹配 n,其中 n 为十六进制转义值。十六进制转义值必须为确定的两个数字长。例如,'\x41' 匹配 "A"。'\x041' 则等价于 '\x04' & "1"。正则表达式中可以使用 ASCII 编码。 |
| \num | 匹配 num,其中 num 是一个正整数。对所获取的匹配的引用。例如,'(.)\1' 匹配两个连续的相同字符。 |
| \n | 标识一个八进制转义值或一个向后引用。如果 \n 之前至少 n 个获取的子表达式,则 n 为向后引用。否则,如果 n 为八进制数字 (0-7),则 n 为一个八进制转义值。 |
| \nm | 标识一个八进制转义值或一个向后引用。如果 \nm 之前至少有 nm 个获得子表达式,则 nm 为向后引用。如果 \nm 之前至少有 n 个获取,则 n 为一个后跟文字 m 的向后引用。如果前面的条件都不满足,若 n 和 m 均为八进制数字 (0-7),则 \nm 将匹配八进制转义值 nm。 |
| \nml | 如果 n 为八进制数字 (0-3),且 m 和 l 均为八进制数字 (0-7),则匹配八进制转义值 nml。 |
| \un | 匹配 n,其中 n 是一个用四个十六进制数字表示的 Unicode 字符。例如, \u00A9 匹配版权符号 (?)。 |

运算符优先级
下表从最高到最低说明了各种正则表达式运算符的优先级顺序
| 运算符 | 描述 |
|---|---|
| \ | 转义符 |
| (), (?:), (?=), [] | 圆括号和方括号 |
| *, +, ?, {n}, {n,}, {n,m} | 限定符 |
| ^, , \任何元字符、任何字符 | 定位点和序列(即:位置和顺序) |
| | | 替换,"或"操作字符具有高于替换运算符的优先级,使得"m|food"匹配"m"或"food"。若要匹配"mood"或"food",请使用括号创建子表达式,从而产生"(m|f)ood"。 |
匹配规则
基本模式匹配
一切从最基本的开始。模式,是正则表达式最基本的元素,它们是一组描述字符串特征的字符。模式可以很简单,由普通的字符串组成,也可以非常复杂,往往用特殊的字符表示一个范围内的字符、重复出现,或表示上下文。例如:
^once
这个模式包含一个特殊的字符 ^,表示该模式只匹配那些以 once 开头的字符串。例如该模式与字符串 "once upon a time" 匹配,与 "There once was a man from NewYork" 不匹配。正如如 ^ 符号表示开头一样,**** 符号用来匹配那些以给定模式结尾的字符串。
bucket$
这个模式与 "Who kept all of this cash in a bucket" 匹配,与 "buckets" 不匹配。字符 ^ 和 **
只匹配字符串 **"bucket"**。如果一个模式不包括 **^** 和 **$**,那么它与任何包含该模式的字符串匹配。例如模式:
once
与字符串
There once was a man from NewYork Who kept all of his cash in a bucket.
是匹配的。
在该模式中的字母 **(o-n-c-e)** 是字面的字符,也就是说,他们表示该字母本身,数字也是一样的。其他一些稍微复杂的字符,如标点符号和白字符(空格、制表符等),要用到转义序列。所有的转义序列都用反斜杠 **\\** 打头。制表符的转义序列是 **\t**。所以如果我们要检测一个字符串是否以制表符开头,可以用这个模式:
^\t
类似的,用 **\n** 表示**"新行"**,**\r** 表示回车。其他的特殊符号,可以用在前面加上反斜杠,如反斜杠本身用 **\\\\** 表示,句号 **.** 用 **\\.** 表示,以此类推。
#### 字符簇
在 INTERNET 的程序中,正则表达式通常用来验证用户的输入。当用户提交一个 FORM 以后,要判断输入的电话号码、地址、EMAIL 地址、信用卡号码等是否有效,用普通的基于字面的字符是不够的。
所以要用一种更自由的描述我们要的模式的办法,它就是字符簇。要建立一个表示所有元音字符的字符簇,就把所有的元音字符放在一个方括号里:
[AaEeIiOoUu]
这个模式与任何元音字符匹配,但只能表示一个字符。用连字号可以表示一个字符的范围,如:
[a-z] // 匹配所有的小写字母 [A-Z] // 匹配所有的大写字母 [a-zA-Z] // 匹配所有的字母 [0-9] // 匹配所有的数字 [0-9.-] // 匹配所有的数字,句号和减号 [ \f\r\t\n] // 匹配所有的白字符
同样的,这些也只表示一个字符,这是一个非常重要的。如果要匹配一个由一个小写字母和一位数字组成的字符串,比如 "z2"、"t6" 或 "g7",但不是 "ab2"、"r2d3" 或 "b52" 的话,用这个模式:
^[a-z][0-9]
尽管 **[a-z]** 代表 26 个字母的范围,但在这里它只能与第一个字符是小写字母的字符串匹配。
前面曾经提到^表示字符串的开头,但它还有另外一个含义。当在一组方括号里使用 **^** 时,它表示"**非**"或"**排除**"的意思,常常用来剔除某个字符。还用前面的例子,我们要求第一个字符不能是数字:
^[^0-9][0-9]
这个模式与 "&5"、"g7"及"-2" 是匹配的,但与 "12"、"66" 是不匹配的。下面是几个排除特定字符的例子:
4 //除了小写字母以外的所有字符 5 //除了()(/)(^)之外的所有字符 6 //除了双引号(")和单引号(')之外的所有字符
特殊字符 **.**(点,句号)在正则表达式中用来表示除了"新行"之外的所有字符。所以模式 **^.5$** 与任何两个字符的、以数字5结尾和以其他非"新行"字符开头的字符串匹配。模式 **.** 可以匹配任何字符串,**换行符(\n、\r)除外**。
PHP的正则表达式有一些内置的通用字符簇,列表如下:
| 字符簇 | 描述 |
| :----------- | :---------------------------------- |
| [[:alpha:]] | 任何字母 |
| [[:digit:]] | 任何数字 |
| [[:alnum:]] | 任何字母和数字 |
| [[:space:]] | 任何空白字符 |
| [[:upper:]] | 任何大写字母 |
| [[:lower:]] | 任何小写字母 |
| [[:punct:]] | 任何标点符号 |
| [[:xdigit:]] | 任何16进制的数字,相当于[0-9a-fA-F] |
#### 确定重复出现
到现在为止,你已经知道如何去匹配一个字母或数字,但更多的情况下,可能要匹配一个单词或一组数字。一个单词有若干个字母组成,一组数字有若干个单数组成。跟在字符或字符簇后面的花括号({})用来确定前面的内容的重复出现的次数。
| 字符簇 | 描述 |
| :--------------- | :------------------------------ |
| ^[a-zA-Z_]$ | 所有的字母和下划线 |
| ^[[:alpha:]]{3}$ | 所有的3个字母的单词 |
| ^a$ | 字母a |
| ^a{4}$ | aaaa |
| ^a{2,4}$ | aa,aaa或aaaa |
| ^a{1,3}$ | a,aa或aaa |
| ^a{2,}$ | 包含多于两个a的字符串 |
| ^a{2,} | 如:aardvark和aaab,但apple不行 |
| a{2,} | 如:baad和aaa,但Nantucket不行 |
| \t{2} | 两个制表符 |
| .{2} | 所有的两个字符 |
这些例子描述了花括号的三种不同的用法。一个数字 **{x}** 的意思是**前面的字符或字符簇只出现x次** ;一个数字加逗号 **{x,}** 的意思是**前面的内容出现x或更多的次数** ;两个数字用逗号分隔的数字 **{x,y}** 表示 **前面的内容至少出现x次,但不超过y次**。我们可以把模式扩展到更多的单词或数字:
^[a-zA-Z0-9_]{1,} // 所有的正整数 ^-{0,1}[0-9]{1,} // 所有的浮点数
最后一个例子不太好理解,是吗?这么看吧:以一个可选的负号 (**[-]?**) 开头 (**^**)、跟着1个或更多的数字(**[0-9]+**)、和一个小数点(**\.**)再跟上1个或多个数字**([0-9]+**),并且后面没有其他任何东西(**$**)。下面你将知道能够使用的更为简单的方法。
特殊字符 **?** 与 **{0,1}** 是相等的,它们都代表着: **0个或1个前面的内容** 或 **前面的内容是可选的** 。所以刚才的例子可以简化为:
^-?[0-9]{1,}.?[0-9]{1,} // 所有包含一个以上的字母、数字或下划线的字符串 ^[1-9][0-9]* // 所有的正整数 ^-?[0-9]+ // 所有的整数 ^[-]?[0-9]+(.[0-9]+)? // 所有的浮点数
当然这并不能从技术上降低正则表达式的复杂性,但可以使它们更容易阅读
## python re模块
```python
import re
re.match函数
re.match 尝试从字符串的起始位置匹配一个模式,如果不是起始位置匹配成功的话,match() 就返回 none
re.match(pattern, string, flags=0)
函数参数说明:
| 参数 | 描述 |
|---|---|
| pattern | 匹配的正则表达式 |
| string | 要匹配的字符串。 |
| flags | 标志位,用于控制正则表达式的匹配方式,如:是否区分大小写,多行匹配等等。参见:[正则表达式基础修饰符] |
匹配成功 re.match 方法返回一个匹配的对象,否则返回 None。
我们可以使用 group(num) 或 groups() 匹配对象函数来获取匹配表达式。
| 匹配对象方法 | 描述 |
|---|---|
| group(num=0) | 匹配的整个表达式的字符串,group() 可以一次输入多个组号,在这种情况下它将返回一个包含那些组所对应值的元组。 |
| groups() | 返回一个包含所有小组字符串的元组,从 1 到 所含的小组号。 |
import re
### match 实例
print(re.match('www', 'www.runoob.com').span()) # 在起始位置匹配
print(re.match('com', 'www.runoob.com')) # 不在起始位置匹配
# 输出为
# (0, 3)
# None
### group 实例
line = "Cats are smarter than dogs"
matchObj = re.match( r'(.*) are (.*?) .*', line, re.M|re.I)
if matchObj:
print "matchObj.group() : ", matchObj.group()
print "matchObj.group(1) : ", matchObj.group(1)
print "matchObj.group(2) : ", matchObj.group(2)
else:
print "No match!!"
# 输出为
# matchObj.group() : Cats are smarter than dogs
# matchObj.group(1) : Cats
# matchObj.group(2) : smarter
re.research 函数
re.search 扫描整个字符串并返回第一个成功的匹配
函数语法:
re.search(pattern, string, flags=0)
函数参数说明:
| 参数 | 描述 |
|---|---|
| pattern | 匹配的正则表达式 |
| string | 要匹配的字符串。 |
| flags | 标志位,用于控制正则表达式的匹配方式,如:是否区分大小写,多行匹配等等。 |
匹配成功re.search方法返回一个匹配的对象,否则返回None。
我们可以使用group(num) 或 groups() 匹配对象函数来获取匹配表达式。
| 匹配对象方法 | 描述 |
|---|---|
| group(num=0) | 匹配的整个表达式的字符串,group() 可以一次输入多个组号,在这种情况下它将返回一个包含那些组所对应值的元组。 |
| groups() | 返回一个包含所有小组字符串的元组,从 1 到 所含的小组号。 |
import re
# search实例
print(re.search('www', 'www.runoob.com').span()) # 在起始位置匹配
print(re.search('com', 'www.runoob.com').span()) # 不在起始位置匹配
# 输出为
# (0, 3)
# (11, 14)
# group实例
line = "Cats are smarter than dogs";
searchObj = re.search( r'(.*) are (.*?) .*', line, re.M|re.I)
if searchObj:
print "searchObj.group() : ", searchObj.group()
print "searchObj.group(1) : ", searchObj.group(1)
print "searchObj.group(2) : ", searchObj.group(2)
else:
print "Nothing found!!"
# 输出为
# searchObj.group() : Cats are smarter than dogs
# searchObj.group(1) : Cats
# searchObj.group(2) : smarter
re.match与re.search的区别
re.match只匹配字符串的开始,如果字符串开始不符合正则表达式,则匹配失败,函数返回None;而re.search匹配整个字符串,直到找到一个匹配
re.sub 检索和替换
语法:
re.sub(pattern, repl, string, count=0, flags=0)
参数:
- pattern : 正则中的模式字符串。
- repl : 替换的字符串,也可为一个函数。
- string : 要被查找替换的原始字符串。
- count : 模式匹配后替换的最大次数,默认 0 表示替换所有的匹配。
import re
phone = "2004-959-559 # 这是一个国外电话号码"
# 删除字符串中的 Python注释
num = re.sub(r'#.*', "", phone)
print "电话号码是: ", num
# 删除非数字(-)的字符串
num = re.sub(r'\D', "", phone)
print "电话号码是 : ", num
# 输出为
# 电话号码是: 2004-959-559
# 电话号码是 : 2004959559
repl 参数可以是一个函数
import re
# 将匹配的数字乘以 2
def double(matched):
value = int(matched.group('value'))
return str(value * 2)
s = 'A23G4HFD567'
print(re.sub('(?P<value>\d+)', double, s))
# 输出为
# A46G8HFD1134
re.compile
compile 函数用于编译正则表达式,生成一个正则表达式( Pattern )对象,供 match() 和 search() 这两个函数使用,语法格式为
re.compile(pattern[, flags])
参数:
- pattern : 一个字符串形式的正则表达式
- flags : 可选,表示匹配模式,比如忽略大小写,多行模式等,具体参数为:
- re.I 忽略大小写
- re.L 表示特殊字符集 \w, \W, \b, \B, \s, \S 依赖于当前环境
- re.M 多行模式
- re.S 即为 . 并且包括换行符在内的任意字符(. 不包括换行符)
- re.U 表示特殊字符集 \w, \W, \b, \B, \d, \D, \s, \S 依赖于 Unicode 字符属性数据库
- re.X 为了增加可读性,忽略空格和 # 后面的注释
>>>import re
>>> pattern = re.compile(r'\d+') # 用于匹配至少一个数字
>>> m = pattern.match('one12twothree34four') # 查找头部,没有匹配
>>> print m
None
>>> m = pattern.match('one12twothree34four', 2, 10) # 从'e'的位置开始匹配,没有匹配
>>> print m
None
>>> m = pattern.match('one12twothree34four', 3, 10) # 从'1'的位置开始匹配,正好匹配
>>> print m # 返回一个 Match 对象
<_sre.SRE_Match object at 0x10a42aac0>
>>> m.group(0) # 可省略 0
'12'
>>> m.start(0) # 可省略 0
3
>>> m.end(0) # 可省略 0
5
>>> m.span(0) # 可省略 0
(3, 5)
findall
在字符串中找到正则表达式所匹配的所有子串,并返回一个列表,如果有多个匹配模式,则返回元组列表,如果没有找到匹配的,则返回空列表
注意: match 和 search 是匹配一次 findall 匹配所有
参数:
- string : 待匹配的字符串
- pos : 可选参数,指定字符串的起始位置,默认为 0
- endpos : 可选参数,指定字符串的结束位置,默认为字符串的长度
import re
pattern = re.compile(r'\d+') # 查找数字
result1 = pattern.findall('runoob 123 google 456')
result2 = pattern.findall('run88oob123google456', 0, 10)
print(result1)
print(result2)
# 输出
# ['123', '456']
# ['88', '12']
pytorch
创建网络的一种快捷方法:Sequential
net = torch.nn.Sequential(
torch.nn.Linear(STATE_SIZE, HIDDEN_SIZE),
torch.nn.ReLU(),
torch.nn.Linear(HIDDEN_SIZE, ACTION_SIZE),
)
2.1 构造张量的函数
torch.tensor() torch.zeros(), torch.zeros_like() torch.ones(), torch.ones_like() torch.full(), torch.full_like() 全填充为指定值 torch.empty(), torch.empty_like() torch.eye() torch.arange(), torch.range(), torch.linspace() torch.logspace() 等比 torch.rand(), torch.rand_like() 标准均匀 torch.randn(), torch.randn_like(), torch.normal() 标准正态 torch.randint(), torch.randint_like() torch.bernoulli() 两点分布 torch.multinomial() torch.randperm() {0,1,2,3...,n-1}的随机排列
2.2 重排张量元素
以下三种不会改变张量的实际位置(浅拷贝)
- reshape()
- squeeze():消除张量中大小为 的维度,
t.squeeze() - unsqueeze():添加一个大小为 的维度,
t.unsqueeze(dim=2)
2.3 张量扩展和拼接
- repeat()
- cat():两个参数,第一个是要拼接的张量的列表,第二个是延哪一个维度
- stack():同上,不同在于 stack 要求拼接的张量大小完全一样,延一个新的维度拼接
2.4 求解优化问题
- 在构造用做自变量的 torch.Tensor 类实例时,应将参数 requires_grad 设置为 True
- 调用张量类实例的成员方法 backward() 可以求偏导,调用完后,自变量的属性 grad 就储存了偏导的数值
from math import pi
import torch
x = torch.tensor([ pi/3 , pi/6 ], requires_grad=True)
f = -((x.cos()**2).sum)**2
print(f'value = {f}')
f.backward()
print(f'grad = {x.grad}')
优化算法与torch.optim包
在梯度下降时,先调用优化器实例的方法 zero_grad() 清空优化器在上次迭代中储存的数据,然后调用 torch.tensor 类实例的方法 backward() 求梯度,最后使用优化器的方法 step() 更新自变量的值
optimizer.zero_grad()
f.backward()
optimizer.step()
使用 torch.optim.SGD 梯度下降的一个实例
from math import pi
import torch
import torch.optim
x = torch.tensor([ pi/3 , pi/6 ], requires_grad=True)
optimizer = torch.optim.SGD([x,], lr=0.1 ,momentum=0)
for step in range(11):
if step:
optimizer.zero_grad()
f.backward()
optimizer.step()
f = -((x.cos()**2).sum)**2
print(f'step {step}: x = {x.tolist()}, f(x) = {f}')
torch.nn子包与损失类
torch.nn.Module 类及其子类可有以下用途
- 表示一个神经网络.如:torch.nn.Sequential 类可以表示一个前馈神经网络
- 表示神经网络的一个层:如 torch.nn.Linear 线性层,torch.nn.ReLU 激活层
- 表示损失:torch.nn.MSELoss,torh.nn.CrossEntropyLoss 等
激活层中逐元素激活分为以下三类
- S 型激活:Sigmoid,Softsign,Tanh,Hardtanh,ReLU6
- 单侧激活:ReLU,LeakyReLU,PReLU,RReLU,Threshold,ELU,SELU,Softplus,LogSigmoid
- 褶皱激活:Hardshrinkage,Softshrinkage,Tanhshrinkage
非逐元素激活
- Softmax,Softmax2d,LogSoftmax
torch.nn 里的损失类都是 torch.nn.Module 类的子类
criterion = torch.nn.MSELoss()
pred = torch.arange(5, requires_grad=True)
y = torch.ones(5)
loss = criterion(pred, y)
loss.backward()
训练集、验证集与训练集
训练集用来计算参数,验证集来判定欠拟合或过拟合,测试机来评价最终结果

| 欠拟合 | 过拟合 | |
|---|---|---|
| 泛化差错主要来源 | 偏差差错 (bias) | 方差差错 (variance) |
| 模型复杂度 | 过低 | 过高 |
| 学习曲线和验证曲线特征 | 收敛到比较大的差错值 | 两条曲线之间差别大 |
| 解决方案 | 增加模型复杂度 | 减小模型复杂度或增大训练集 |
2.5 标准化
- 批标准化( batch normalization ):对同一通道使用相同的均值和方差进行归一化,更适用于特征提取这样的应用
- 实例标准化( instance normalization ):对同一通道使用不同的均值和方差进行归一化,更适用于生成数据这样的应用
| 标准化操作类型 | 维度 | 标准化类 | 输入输出张量维度 | 适用网络 |
|---|---|---|---|---|
| 批标准化 | 1 | torch.nn.BatchNorm1d | 前馈神经网络 | |
| 批标准化 | 2 | torch.nn.BatchNorm2d | 前馈神经网络 | |
| 批标准化 | 3 | torch.nn.BatchNorm3d | 前馈神经网络 | |
| 实例标准化 | 1 | torch.nn.InstanceNorm1d | 前馈神经网络 | |
| 实例标准化 | 2 | torch.nn.InstanceNorm2d | 前馈神经网络 | |
| 实例标准化 | 3 | torch.nn.InstanceNorm3d | 前馈神经网络 | |
| 层标准化 | 不限 | torch.nn.LayerNorm | 前馈神经网络 |
2.6 网络权重初始化
pytorch 中完成权重初始化需要 torch.nn.init 子包和 torch.nn.Module 类成员方法 apply().
| 函数名 | 元素分布 | 分布参数确定方法 |
|---|---|---|
| torch.nn.init.uniform_() | 均匀分布 | 传入表示最小值的参数 a (默认为 0 )和表示最大值的参数 b (默认为 1 ) |
| torch.nn.init.normal_() | 正态分布 | 传入表示均值的参数 mean (默认为 0 )和表示方差的参数 std (默认为 1 ) |
| torch.nn.init.constant_() | 常量 | 传入常量 vaL |
| torch.nn.init.xavier_uniform_() | 均匀分布 | 均值为 0 ,标准差 根据输入的张量大小和增益函数 gain 计算得到 |
| torch.nn.init.xavier_uniform_() | 均匀分布 | 均值为 0 ,标准差 根据输入的张量大小和增益函数 gain 计算得到 |
apply() 方法有一个参数,参数是一个 python 函数,这个函数的参数必须是 torch.nn.Module 类.
import torch.nn.init as init
def weights_init(m):
init.xavier_normal_(m.weight)
init.constant_(m.bias, 0)
2.7 卷积神经网络
对一维卷积,设 为输入张量,大小为 , 为卷积核,大小为 ,输出张量为 ,大小为 ,则有
补全 (pad) 运算

在补零后 (前后各补 ,) ,相应的张量维度为
核的膨胀(dilate),基本互相关中,每个权重连续对应着输入张量中的元素,此时可认为膨胀系数为 ,膨胀前后核大小关系为 图 8-4 给出了膨胀系数 的例子.膨胀前,核的大小为 ,膨胀后,

步幅(stride),基本互相关中,卷积核每次相对输入张量 向右移动一个元素的位置并得到一个输出张量,一共 个输出.将此输出大小记为 视为可以认为基本互相关操作的步幅 ,如果考虑更大步幅,则有
补全、步幅、膨胀可以综合使用.综合前文,输入大小 ,输出大小 ,核张量大小 ,两侧分别补全数 和 ,步幅 ,膨胀系数 之间的关系满足 将以上几式综合起来,可以得到
torch.nn 里的卷积层
| 运算类型 | 运算维度 | torch.nn.Module子类 | 类实例输入张量的大小 | 类实例输出张量的大小 |
|---|---|---|---|---|
| 互相关 | 1 | torch.nn.Conv1d | ||
| 互相关 | 2 | torch.nn.Conv2d | ||
| 互相关 | 3 | torch.nn.Conv3d |
为样本的计数, 表示数据的通道数,即一条数据有几个 维张量.卷积层的输出通道数表示最多支持的特征个数.因为每个通道使用相同的卷积核计算,每个卷积核只能提取一种特征.
conv = torch.nn.Conv2d(16, 33, kernel_size={3, 5}, stride={2, 1}, padding={4, 2}, dilation={3, 1})
inputs = torch.rand(20, 16, 50, 100) #20条样本,16个通道,每个通道大小为 50*100
outputs = conv(inputs)
outputs.size()
张量的池化
池化 (pooling),核不需要权重
- 最大池化(max pool):输出张量的每个元素都是若干个输入张量的最大值
- 平均池化(average pool):输出元素由若干个输入元素求平均得到
- 池化( pool):计算输入元素组合的 范数


以下为不带“自适应”(adaptive)的版本,带自适应只需在 MaxPool1d 前加上 Adaptive,此时不能设置补全数等,他会自动帮你计算
| 运算类型 | 运算维度 | torch.nn.Module子类 | 类实例输入张量的大小 | 类实例输出张量的大小 |
|---|---|---|---|---|
| 最大池化 | 1 | torch.nn.MaxPool1d | ||
| 最大池化 | 2 | torch.nn.MaxPool2d | ||
| 最大池化 | 3 | torch.nn.MaxPool3d | ||
| 平均池化 | 1 | torch.nn.AvgPool1d | ||
| 平均池化 | 2 | torch.nn.AvgPool2d | ||
| 平均池化 | 3 | torch.nn.AvgPool3d | ||
| 池化 | 1 | torch.nn.LPPool1d | ||
| 池化 | 2 | torch.nn.LPPool2d | ||
| 最大反池化 | 1 | torch.nn.MaxUnpool1d | ||
| 最大反池化 | 2 | torch.nn.MaxUnpool2d | ||
| 最大反池化 | 3 | torch.nn.MaxUnpool3d |
张量的上采样
张量的上采样(up-sample),将输入张量的每个维度大小扩展若干倍.
- 最邻近上采样( nearest up-sample ):按照一个比例因子( scale factor )将每个元素重复若干次
- 线性插值上采样( linearup-sample )
pytorch 中上采样用的是 torch.nn 的子包 torch.nn.Unsample 类.
| 运算类型 | 运算维度 | torch.nn.Unsample类实例构造参数 | 类实例输入张量的大小 | 类实例输出张量的大小 |
|---|---|---|---|---|
| 最邻近上采样 | 1 | mode='nearest'(默认值) | ||
| 最邻近上采样 | 2 | mode='nearest'(默认值) | ||
| 最邻近上采样 | 3 | mode='nearest'(默认值) | ||
| 线性上采样 | 1 | mode='linear' | ||
| 线性上采样 | 2 | mode='bilinear' | ||
| 线性上采样 | 3 | mode='trilinear' |
张量的补全运算
- 常数补全( constant pad ):输入张量前后补上常数
- 重复补全( replication pad ):用最边上的值补全
- 反射补全( reflection pad ):以边界为对称轴补全
| 运算类型 | 运算维度 | torch.nn.Module子类 | 类实例输入张量的大小 | 类实例输出张量的大小 |
|---|---|---|---|---|
| 常数补全 | 2 | torch.nn.ConstantPad2d | ||
| 重复补全 | 2 | torch.nn.ReplicationPad2d | ||
| 反射补全 | 2 | torch.nn.ReflectionPad2d | ||
| 反射补全 | 3 | torch.nn.Reflection3d |
inputs = torch.arange(12).view(1, 1, 3, 4)
pad = nn.ConstantPad2d(padding=[1, 1, 1, 1], value=-1)
pad = nn.Replication2d(padding=[1, 1, 1, 1])
pad = nn.Reflection2d(padding=[1, 1, 1, 1])
例如实现下图的卷积网络,可以参考的构建网络方法:

import torch.nn
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv0 = torch.nn.Conv2d(1, 64, kernel_size=3, padding=1)
self.relu1 = torch.nn.ReLU()
self.conv2 = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.relu3 = torch.nn.ReLU()
self.pool4 = torch.nn.MaxPool2d(stride=2, kernel_size=2)
self.fc5 = torch.nn.Linear(128*14*14, 1024)
self.relu6 = torch.nn.ReLU()
self.drop7 = torch.nn.Dropout(p=0.5)
self.fc8 = torch.nn.Linear(1024, 10)
def forward(self, x):
x = self.conv0(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.relu3(x)
x = self.pool4(x)
x = x.view(-1, 128 * 14 * 14)
x = self.fc5(x)
x = self.relu6(x)
x = self.drop7(x)
x = self.fc8(x)
return x
net = Net()
另外可用 sequential 方法
import torch.nn
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, 64, kernel_size=3, padding=1)
torch.nn.ReLU()
torch.nn.Conv2d(64, 128, kernel_size=3, padding=1)
torch.nn.ReLU()
torch.nn.MaxPool2d(stride=2, kernel_size=2))
self.dense = torch.nn.Sequential(
torch.nn.Linear(128*14*14, 1024)
torch.nn.ReLU()
torch.nn.Dropout(p=0.5)
torch.nn.Linear(1024, 10))
def forward(self, x):
x = self.conv(x)
x = x.view(-1, 128 * 14 * 14)
x = self.dense(x)
return x
net = Net()
2.8 循环神经网络
TODO:循环神经网络
以下是 LSTM 示例
import torch.nn
class Net(torch.nn.Module):
def __init__(self, input_size, hidden_size):
super(Net, self).__init__()
self.rnn = torch.nn.LSTM(input_size, hidden_size)
self.fc = torch.nn.Linear(hidden_size, 1)
def forward(self, x):
x = x[:, :, None]
x, _ = self.rnn(x)
x = self.fc(x)
x = x[:, :, 0]
return x
net = Net(input_size=1, hidden_size=5)
2.9 生成对抗网络
- 生成网络( generative network ):一般一条随机输入是一个有多个元素的张量 ,张量 的取值空间称为“潜在空间”( latent space ),张量 的元素个数称为“潜在大小”( latent size ).生成网络 可以将这条潜在张量样本 映射为一条数据张量 .
- 鉴别网络( discriminative network ):对生成网络生成的数据进行判定.
以 记交叉熵损失函数
目的:训练鉴别网络 使得 训练生成网络使得
以下是CIFAR-10图像生成的实例
'''读取数据'''
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torchvision.utils import save_image
dataset = CIFAR10(root='./data', download=True,
transform=transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
for batch_idx, data in enumerate(dataloader):
real_images, _ = data
batch_size = real_images.size(0)
print ('#{} has {} images.'.format(batch_idx, batch_size))
if batch_idx % 100 == 0:
path = './data/CIFAR10_shuffled_batch{:03d}.png'.format(batch_idx)
save_image(real_images, path, normalize=True)
'''生成网络与鉴别网络的搭建'''
import torch.nn as nn
# 搭建生成网络
latent_size = 64 # 潜在大小
n_channel = 3 # 输出通道数
n_g_feature = 64 # 生成网络隐藏层大小
gnet = nn.Sequential(
# 输入大小 = (64, 1, 1)
nn.ConvTranspose2d(latent_size, 4 * n_g_feature, kernel_size=4,
bias=False),
nn.BatchNorm2d(4 * n_g_feature),
nn.ReLU(),
# 大小 = (256, 4, 4)
nn.ConvTranspose2d(4 * n_g_feature, 2 * n_g_feature, kernel_size=4,
stride=2, padding=1, bias=False),
nn.BatchNorm2d(2 * n_g_feature),
nn.ReLU(),
# 大小 = (128, 8, 8)
nn.ConvTranspose2d(2 * n_g_feature, n_g_feature, kernel_size=4,
stride=2, padding=1, bias=False),
nn.BatchNorm2d(n_g_feature),
nn.ReLU(),
# 大小 = (64, 16, 16)
nn.ConvTranspose2d(n_g_feature, n_channel, kernel_size=4,
stride=2, padding=1),
nn.Sigmoid(),
# 图片大小 = (3, 32, 32)
)
print (gnet)
# 搭建鉴别网络
n_d_feature = 64 # 鉴别网络隐藏层大小
dnet = nn.Sequential(
# 图片大小 = (3, 32, 32)
nn.Conv2d(n_channel, n_d_feature, kernel_size=4,
stride=2, padding=1),
nn.LeakyReLU(0.2),
# 大小 = (64, 16, 16)
nn.Conv2d(n_d_feature, 2 * n_d_feature, kernel_size=4,
stride=2, padding=1, bias=False),
nn.BatchNorm2d(2 * n_d_feature),
nn.LeakyReLU(0.2),
# 大小 = (128, 8, 8)
nn.Conv2d(2 * n_d_feature, 4 * n_d_feature, kernel_size=4,
stride=2, padding=1, bias=False),
nn.BatchNorm2d(4 * n_d_feature),
nn.LeakyReLU(0.2),
# 大小 = (256, 4, 4)
nn.Conv2d(4 * n_d_feature, 1, kernel_size=4),
# 对数赔率张量大小 = (1, 1, 1)
)
print(dnet)
'''网络初始化'''
import torch.nn.init as init
def weights_init(m): # 用于初始化权重值的函数
if type(m) in [nn.ConvTranspose2d, nn.Conv2d]:
init.xavier_normal_(m.weight)
elif type(m) == nn.BatchNorm2d:
init.normal_(m.weight, 1.0, 0.02)
init.constant_(m.bias, 0)
gnet.apply(weights_init)
dnet.apply(weights_init)
'''训练并输出图片'''
import torch
import torch.optim
# 损失
criterion = nn.BCEWithLogitsLoss()
# 优化器
goptimizer = torch.optim.Adam(gnet.parameters(),
lr=0.0002, betas=(0.5, 0.999))
doptimizer = torch.optim.Adam(dnet.parameters(),
lr=0.0002, betas=(0.5, 0.999))
# 用于测试的固定噪声,用来查看相同的潜在张量在训练过程中生成图片的变换
batch_size = 64
fixed_noises = torch.randn(batch_size, latent_size, 1, 1)
# 训练过程
epoch_num = 10
for epoch in range(epoch_num):
for batch_idx, data in enumerate(dataloader):
# 载入本批次数据
real_images, _ = data
batch_size = real_images.size(0)
# 训练鉴别网络
labels = torch.ones(batch_size) # 真实数据对应标签为1
preds = dnet(real_images) # 对真实数据进行判别
outputs = preds.reshape(-1)
dloss_real = criterion(outputs, labels) # 真实数据的鉴别器损失
dmean_real = outputs.sigmoid().mean()
# 计算鉴别器将多少比例的真数据判定为真,仅用于输出显示
noises = torch.randn(batch_size, latent_size, 1, 1) # 潜在噪声
fake_images = gnet(noises) # 生成假数据
labels = torch.zeros(batch_size) # 假数据对应标签为0
fake = fake_images.detach()
# 使得梯度的计算不回溯到生成网络,可用于加快训练速度.删去此步结果不变
preds = dnet(fake) # 对假数据进行鉴别
outputs = preds.view(-1)
dloss_fake = criterion(outputs, labels) # 假数据的鉴别器损失
dmean_fake = outputs.sigmoid().mean()
# 计算鉴别器将多少比例的假数据判定为真,仅用于输出显示
dloss = dloss_real + dloss_fake # 总的鉴别器损失
dnet.zero_grad()
dloss.backward()
doptimizer.step()
# 训练生成网络
labels = torch.ones(batch_size)
# 生成网络希望所有生成的数据都被认为是真数据
preds = dnet(fake_images) # 把假数据通过鉴别网络
outputs = preds.view(-1)
gloss = criterion(outputs, labels) # 真数据看到的损失
gmean_fake = outputs.sigmoid().mean()
# 计算鉴别器将多少比例的假数据判定为真,仅用于输出显示
gnet.zero_grad()
gloss.backward()
goptimizer.step()
# 输出本步训练结果
print('[{}/{}]'.format(epoch, epoch_num) +
'[{}/{}]'.format(batch_idx, len(dataloader)) +
'鉴别网络损失:{:g} 生成网络损失:{:g}'.format(dloss, gloss) +
'真数据判真比例:{:g} 假数据判真比例:{:g}/{:g}'.format(
dmean_real, dmean_fake, gmean_fake))
if batch_idx % 100 == 0:
fake = gnet(fixed_noises) # 由固定潜在张量生成假数据
save_image(fake, # 保存假数据
'./data/images_epoch{:02d}_batch{:03d}.png'.format(
epoch, batch_idx))
从代码结构方面优化加速python
0. 代码优化原则
本文会介绍不少的 Python 代码加速运行的技巧。在深入代码优化细节之前,需要了解一些代码优化基本原则。
第一个基本原则是不要过早优化。很多人一开始写代码就奔着性能优化的目标,“让正确的程序更快要比让快速的程序正确容易得多”。因此,优化的前提是代码能正常工作。过早地进行优化可能会忽视对总体性能指标的把握,在得到全局结果前不要主次颠倒。
第二个基本原则是权衡优化的代价。优化是有代价的,想解决所有性能的问题是几乎不可能的。通常面临的选择是时间换空间或空间换时间。另外,开发代价也需要考虑。
第三个原则是不要优化那些无关紧要的部分。如果对代码的每一部分都去优化,这些修改会使代码难以阅读和理解。如果你的代码运行速度很慢,首先要找到代码运行慢的位置,通常是内部循环,专注于运行慢的地方进行优化。在其他地方,一点时间上的损失没有什么影响。
1. 避免全局变量
# 不推荐写法。代码耗时:26.8秒
import math
size = 10000
for x in range(size):
for y in range(size):
z = math.sqrt(x) + math.sqrt(y)
许多程序员刚开始会用 Python 语言写一些简单的脚本,当编写脚本时,通常习惯了直接将其写为全局变量,例如上面的代码。但是,由于全局变量和局部变量实现方式不同,定义在全局范围内的代码运行速度会比定义在函数中的慢不少。通过将脚本语句放入到函数中,通常可带来 15% - 30% 的速度提升。
# 推荐写法。代码耗时:20.6秒
import math
def main(): # 定义到函数中,以减少全部变量使用
size = 10000
for x in range(size):
for y in range(size):
z = math.sqrt(x) + math.sqrt(y)
main()
2. 避免.
2.1 避免模块和函数属性访问
# 不推荐写法。代码耗时:14.5秒
import math
def computeSqrt(size: int):
result = []
for i in range(size):
result.append(math.sqrt(i))
return result
def main():
size = 10000
for _ in range(size):
result = computeSqrt(size)
main()
每次使用.(属性访问操作符时)会触发特定的方法,如__getattribute__()和__getattr__(),这些方法会进行字典操作,因此会带来额外的时间开销。通过from import语句,可以消除属性访问。
# 第一次优化写法。代码耗时:10.9秒
from math import sqrt
def computeSqrt(size: int):
result = []
for i in range(size):
result.append(sqrt(i)) # 避免math.sqrt的使用
return result
def main():
size = 10000
for _ in range(size):
result = computeSqrt(size)
main()
在第 1 节中我们讲到,局部变量的查找会比全局变量更快,因此对于频繁访问的变量sqrt,通过将其改为局部变量可以加速运行。
# 第二次优化写法。代码耗时:9.9秒
import math
def computeSqrt(size: int):
result = []
sqrt = math.sqrt # 赋值给局部变量
for i in range(size):
result.append(sqrt(i)) # 避免math.sqrt的使用
return result
def main():
size = 10000
for _ in range(size):
result = computeSqrt(size)
main()
除了math.sqrt外,computeSqrt函数中还有.的存在,那就是调用list的append方法。通过将该方法赋值给一个局部变量,可以彻底消除computeSqrt函数中for循环内部的.使用。
# 推荐写法。代码耗时:7.9秒
import math
def computeSqrt(size: int):
result = []
append = result.append
sqrt = math.sqrt # 赋值给局部变量
for i in range(size):
append(sqrt(i)) # 避免 result.append 和 math.sqrt 的使用
return result
def main():
size = 10000
for _ in range(size):
result = computeSqrt(size)
main()
2.2 避免类内属性访问
# 不推荐写法。代码耗时:10.4秒
import math
from typing import List
class DemoClass:
def __init__(self, value: int):
self._value = value
def computeSqrt(self, size: int) -> List[float]:
result = []
append = result.append
sqrt = math.sqrt
for _ in range(size):
append(sqrt(self._value))
return result
def main():
size = 10000
for _ in range(size):
demo_instance = DemoClass(size)
result = demo_instance.computeSqrt(size)
main()
避免.的原则也适用于类内属性,访问self._value的速度会比访问一个局部变量更慢一些。通过将需要频繁访问的类内属性赋值给一个局部变量,可以提升代码运行速度。
# 推荐写法。代码耗时:8.0秒
import math
from typing import List
class DemoClass:
def __init__(self, value: int):
self._value = value
def computeSqrt(self, size: int) -> List[float]:
result = []
append = result.append
sqrt = math.sqrt
value = self._value
for _ in range(size):
append(sqrt(value)) # 避免 self._value 的使用
return result
def main():
size = 10000
for _ in range(size):
demo_instance = DemoClass(size)
demo_instance.computeSqrt(size)
main()
3. 避免不必要的抽象
# 不推荐写法,代码耗时:0.55秒
class DemoClass:
def __init__(self, value: int):
self.value = value
@property
def value(self) -> int:
return self._value
@value.setter
def value(self, x: int):
self._value = x
def main():
size = 1000000
for i in range(size):
demo_instance = DemoClass(size)
value = demo_instance.value
demo_instance.value = i
main()
任何时候当你使用额外的处理层(比如装饰器、属性访问、描述器)去包装代码时,都会让代码变慢。大部分情况下,需要重新进行审视使用属性访问器的定义是否有必要,使用getter/setter函数对属性进行访问通常是 C/C++ 程序员遗留下来的代码风格。如果真的没有必要,就使用简单属性。
# 推荐写法,代码耗时:0.33秒
class DemoClass:
def __init__(self, value: int):
self.value = value # 避免不必要的属性访问器
def main():
size = 1000000
for i in range(size):
demo_instance = DemoClass(size)
value = demo_instance.value
demo_instance.value = i
main()
4. 避免数据复制
4.1 避免无意义的数据复制
# 不推荐写法,代码耗时:6.5秒
def main():
size = 10000
for _ in range(size):
value = range(size)
value_list = [x for x in value]
square_list = [x * x for x in value_list]
main()
上面的代码中value_list完全没有必要,这会创建不必要的数据结构或复制。
# 推荐写法,代码耗时:4.8秒
def main():
size = 10000
for _ in range(size):
value = range(size)
square_list = [x * x for x in value] # 避免无意义的复制
main()
另外一种情况是对 Python 的数据共享机制过于偏执,并没有很好地理解或信任 Python 的内存模型,滥用 copy.deepcopy()之类的函数。通常在这些代码中是可以去掉复制操作的。
4.2 交换值时不使用中间变量
# 不推荐写法,代码耗时:0.07秒
def main():
size = 1000000
for _ in range(size):
a = 3
b = 5
temp = a
a = b
b = temp
main()
上面的代码在交换值时创建了一个临时变量temp,如果不借助中间变量,代码更为简洁、且运行速度更快。
# 推荐写法,代码耗时:0.06秒
def main():
size = 1000000
for _ in range(size):
a = 3
b = 5
a, b = b, a # 不借助中间变量
main()
4.3 字符串拼接用join而不是+
# 不推荐写法,代码耗时:2.6秒
import string
from typing import List
def concatString(string_list: List[str]) -> str:
result = ''
for str_i in string_list:
result += str_i
return result
def main():
string_list = list(string.ascii_letters * 100)
for _ in range(10000):
result = concatString(string_list)
main()
当使用a + b拼接字符串时,由于 Python 中字符串是不可变对象,其会申请一块内存空间,将a和b分别复制到该新申请的内存空间中。
# 推荐写法,代码耗时:0.3秒
import string
from typing import List
def concatString(string_list: List[str]) -> str:
return ''.join(string_list) # 使用 join 而不是 +
def main():
string_list = list(string.ascii_letters * 100)
for _ in range(10000):
result = concatString(string_list)
main()
5. 利用if条件的短路特性
# 不推荐写法,代码耗时:0.05秒
from typing import List
def concatString(string_list: List[str]) -> str:
abbreviations = {'cf.', 'e.g.', 'ex.', 'etc.', 'flg.', 'i.e.', 'Mr.', 'vs.'}
abbr_count = 0
result = ''
for str_i in string_list:
if str_i in abbreviations:
result += str_i
return result
def main():
for _ in range(10000):
string_list = ['Mr.', 'Hat', 'is', 'Chasing', 'the', 'black', 'cat', '.']
result = concatString(string_list)
main()
if 条件的短路特性是指对if a and b这样的语句, 当a为False时将直接返回,不再计算b;对于if a or b这样的语句,当a为True时将直接返回,不再计算b。因此, 为了节约运行时间,对于or语句,应该将值为True可能性比较高的变量写在or前,而and应该推后。
# 推荐写法,代码耗时:0.03秒
from typing import List
def concatString(string_list: List[str]) -> str:
abbreviations = {'cf.', 'e.g.', 'ex.', 'etc.', 'flg.', 'i.e.', 'Mr.', 'vs.'}
abbr_count = 0
result = ''
for str_i in string_list:
if str_i[-1] == '.' and str_i in abbreviations: # 利用 if 条件的短路特性
result += str_i
return result
def main():
for _ in range(10000):
string_list = ['Mr.', 'Hat', 'is', 'Chasing', 'the', 'black', 'cat', '.']
result = concatString(string_list)
main()
6. 循环优化
6.1 用for循环代替while循环
# 不推荐写法。代码耗时:6.7秒
def computeSum(size: int) -> int:
sum_ = 0
i = 0
while i < size:
sum_ += i
i += 1
return sum_
def main():
size = 10000
for _ in range(size):
sum_ = computeSum(size)
main()
Python 的for循环比while循环快不少。
# 推荐写法。代码耗时:4.3秒
def computeSum(size: int) -> int:
sum_ = 0
for i in range(size): # for 循环代替 while 循环
sum_ += i
return sum_
def main():
size = 10000
for _ in range(size):
sum_ = computeSum(size)
main()
6.2 使用隐式for循环代替显式for循环
针对上面的例子,更进一步可以用隐式for循环来替代显式for循环
# 推荐写法。代码耗时:1.7秒
def computeSum(size: int) -> int:
return sum(range(size)) # 隐式 for 循环代替显式 for 循环
def main():
size = 10000
for _ in range(size):
sum = computeSum(size)
main()
6.3 减少内层for循环的计算
# 不推荐写法。代码耗时:12.8秒
import math
def main():
size = 10000
sqrt = math.sqrt
for x in range(size):
for y in range(size):
z = sqrt(x) + sqrt(y)
main()
上面的代码中sqrt(x)位于内侧for循环, 每次训练过程中都会重新计算一次,增加了时间开销。
# 推荐写法。代码耗时:7.0秒
import math
def main():
size = 10000
sqrt = math.sqrt
for x in range(size):
sqrt_x = sqrt(x) # 减少内层 for 循环的计算
for y in range(size):
z = sqrt_x + sqrt(y)
main()
7. 使用numba.jit
我们沿用上面介绍过的例子,在此基础上使用numba.jit。 numba可以将 Python 函数 JIT 编译为机器码执行,大大提高代码运行速度。关于numba的更多信息见下面的主页:
http://numba.pydata.org/numba.pydata.org/
# 推荐写法。代码耗时:0.62秒
import numba
@numba.jit
def computeSum(size: float) -> int:
sum = 0
for i in range(size):
sum += i
return sum
def main():
size = 10000
for _ in range(size):
sum = computeSum(size)
main()
8. 选择合适的数据结构
Python 内置的数据结构如str, tuple, list, set, dict底层都是 C 实现的,速度非常快,自己实现新的数据结构想在性能上达到内置的速度几乎是不可能的。
list类似于 C++ 中的std::vector,是一种动态数组。其会预分配一定内存空间,当预分配的内存空间用完,又继续向其中添加元素时,会申请一块更大的内存空间,然后将原有的所有元素都复制过去,之后销毁之前的内存空间,再插入新元素。删除元素时操作类似,当已使用内存空间比预分配内存空间的一半还少时,会另外申请一块小内存,做一次元素复制,之后销毁原有大内存空间。因此,如果有频繁的新增、删除操作,新增、删除的元素数量又很多时,list的效率不高。此时,应该考虑使用collections.deque。collections.deque是双端队列,同时具备栈和队列的特性,能够在两端进行 O(1) 复杂度的插入和删除操作。
list的查找操作也非常耗时。当需要在list频繁查找某些元素,或频繁有序访问这些元素时,可以使用bisect维护list对象有序并在其中进行二分查找,提升查找的效率。
另外一个常见需求是查找极小值或极大值,此时可以使用heapq模块将list转化为一个堆,使得获取最小值的时间复杂度是 O(1) 。
下面的网页给出了常用的 Python 数据结构的各项操作的时间复杂度: TimeComplexity - Python WiKi
python之禅
print写法
name = 'ROSE'
country = 'China'
age = 20
print('hi, my name is {}. im from {}, and im {}'.format(name,country,age))
最简单写法
print(f'hi, my name is {name}, im from {country}, and im {age+1}')
for 循环时使用 enumerate 可返回两个参数,前一个是 index ,第二个是对应参数
for idx,step in enumerate(range(10))
@staticmethod
静态方法, 不强制要求传递参数
@classmethod
类方法, 不需要实例化, 不需要self参数, 但第一个参数需要是表示自身类的cls参数, 可以用来调用类的属性, 类的方法, 实例化对象等
### 类特殊方法
class Test():
def __init__():
pass
def __enter__():
'''使用with语句创建示例时会自动运行此方法'''
pass
‘’‘
with Test() as t:
pass
’‘’
def __exit__():
'''使用with语句创建实例, 在结束时自动调用该方法'''
pass
def __str__():
'''可print(实例)'''
return ‘我是Test类’
def __setattr__, __getattr__, __getattribute__, __delattr__:
'''对属性进行操作'''
pass
def __call__():
'''能让把实例化对象直接当做函数来调用'''
print(1)
'''
a = Test()
in: a()
out: 1
'''
def __contains__, __len__():
'''类作为容器'''
pass
# HDFStore`
with pd.HDFStore('iv_hv.h5') as store:
c = store.keys()
'''matplotlib
':' 点虚线
'-' 实线
'--' 破折线
'-.' 点划线
添加水平垂直线
plt.axhline(y=0,ls=":",c="yellow") 水平直线
plt.axvline(x=4,ls="-",c="green") 垂直直线
'''

自动发邮件
import smtplib
from smtplib import SMTP_SSL
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from email.header import Header
from email.mime.application import MIMEApplication # 用于添加附件
host_server = 'smtp.qq.com' # qq邮箱smtp服务器
sender_qq = '359582058@qq.com' # 发件人邮箱
pwd = 'bglhxfrujynobhda'qq
pwd = 'EGGZASTFLHVGCBRU'163
receiver = '13918949838@163.com'
mail_title = 'Python自动发送邮件' # 邮件标题
# 邮件正文内容
mail_content = "您好"
msg = MIMEMultipart()
msg["Subject"] = Header(mail_title, 'utf-8')
msg["From"] = sender_qq
msg["To"] = Header("测试邮箱", "utf-8")
msg.attach(MIMEText(mail_content, 'html'))
attachment = MIMEApplication(open('复权.xlsx', 'rb').read())
attachment["Content-Type"] = 'application/octet-stream'
# 给附件重命名
basename = "复权.xlsx"
attachment.add_header('Content-Disposition', 'attachment',
filename=('utf-8', '', basename)) # 注意:此处basename要转换为gbk编码,否则中文会有乱码。
msg.attach(attachment)
try:
smtp = SMTP_SSL(host_server) # ssl登录连接到邮件服务器
smtp.set_debuglevel(1) # 0是关闭,1是开启debug
smtp.ehlo(host_server) # 跟服务器打招呼,告诉它我们准备连接,最好加上这行代码
smtp.login(sender_qq, pwd)
smtp.sendmail(sender_qq, receiver, msg.as_string())
smtp.quit()
print("邮件发送成功")
except smtplib.SMTPException:
print("无法发送邮件")
py23com
from win32com.client import makepy
makepy.main() # 跳出窗口, 创建静态代理static proxy
win32com.client.constant.__d
Rust相关
教程来自于Rust语言圣经
Rust基础知识
Rust基础知识
在线运行测试
fn main(){ println!("Hello, world!") }
安装rust(202302)
macOS
$ curl --proto '=https' --tlsv1.2 https://sh.rustup.rs -sSf | sh
显示
Rust is installed now. Great!
则安装成功!
安装C语言编译器
$ xcode-select --install
Windows
先安装Microsoft C++ Build Tools, 勾选安装C++环境. 然后将Rust所需的msvc命令行程序手动添加到环境变量中, 其位于%Visual Studio 安装位置%\VC\Tools\MSVC\%version%\bin\Hostx64\x64下.
在RUSTUP-INIT下载系统对应的版本,
PS C:\Users\Hehongyuan> rustup-init.exe
......
Current installation options:
default host triple: x86_64-pc-windows-msvc
default toolchain: stable (default)
profile: default
modify PATH variable: yes
1) Proceed with installation (default)
2) Customize installation
3) Cancel installation
cargo
上面的命令使用 cargo new 创建一个项目,项目名是 world_hello, 该项目的结构和配置文件都是由 cargo 生成,意味着我们的项目被 cargo 所管理.
下面来看看创建的项目结构:
$ tree
.
├── .git
├── .gitignore
├── Cargo.toml
└── src
└── main.rs
运行项目
有两种方式可以运行项目:
-
cargo run -
手动编译和运行项目
首先来看看第一种方式,在之前创建的 world_hello 目录下运行:
$ cargo run
Compiling world_hello v0.1.0 (/Users/sunfei/development/rust/world_hello)
Finished dev [unoptimized + debuginfo] target(s) in 0.43s
Running `target/debug/world_hello`
Hello, world!
好了,你已经看到程序的输出:"Hello, world"。
如果你安装的 Rust 的 host triple 是 x86_64-pc-windows-msvc 并确认 Rust 已经正确安装,但在终端上运行上述命令时,出现类似如下的错误摘要 linking with `link.exe` failed: exit code: 1181,请使用 Visual Studio Installer 安装 Windows SDK。
上述代码,cargo run 首先对项目进行编译,然后再运行,因此它实际上等同于运行了两个指令,下面我们手动试一下编译和运行项目:
编译
$ cargo build
Finished dev [unoptimized + debuginfo] target(s) in 0.00s
运行
$ ./target/debug/world_hello
Hello, world!
行云流水,但谈不上一气呵成。在调用的时候,路径 ./target/debug/world_hello 中有一个明晃晃的 debug 字段,没错我们运行的是 debug 模式,在这种模式下,代码的编译速度会非常快,可是福兮祸所依,运行速度就慢了. 原因是,在 debug 模式下,Rust 编译器不会做任何的优化,只为了尽快的编译完成,让开发流程更加顺畅。
想要高性能的代码怎么办? 简单,添加 --release 来编译:
cargo run --releasecargo build --release
试着运行一下我们高性能的 release 程序:
$ ./target/release/world_hello
Hello, world!
cargo check
当项目大了后,cargo run 和 cargo build 不可避免的会变慢,那么有没有更快的方式来验证代码的正确性呢?
cargo check 是我们在代码开发过程中最常用的命令,它的作用很简单:快速的检查一下代码能否编译通过。因此该命令速度会非常快,能节省大量的编译时间。
$ cargo check
Checking world_hello v0.1.0 (/Users/sunfei/development/rust/world_hello)
Finished dev [unoptimized + debuginfo] target(s) in 0.06s
Rust 虽然编译速度还行,但是还是不能与 Go 语言相提并论,因为 Rust 需要做很多复杂的编译优化和语言特性解析,甚至连如何优化编译速度都成了一门学问: 优化编译速度
Cargo.toml 和 Cargo.lock
Cargo.toml 和 Cargo.lock 是 cargo 的核心文件,它的所有活动均基于此二者。
-
Cargo.toml是cargo特有的项目数据描述文件。它存储了项目的所有元配置信息,如果 Rust 开发者希望 Rust 项目能够按照期望的方式进行构建、测试和运行,那么,必须按照合理的方式构建Cargo.toml。 -
Cargo.lock文件是cargo工具根据同一项目的toml文件生成的项目依赖详细清单,因此我们一般不用修改
什么情况下该把
Cargo.lock上传到 git 仓库里?很简单,当你的项目是一个可运行的程序时,就上传Cargo.lock,如果是一个依赖库项目,那么请把它添加到.gitignore中
现在用 VSCode 打开上面创建的"世界,你好"项目,然后进入根目录的 Cargo.toml 文件,可以看到该文件包含不少信息:
package 配置段落
package 中记录了项目的描述信息,典型的如下:
[package]
name = "world_hello"
version = "0.1.0"
edition = "2021"
name 字段定义了项目名称,version 字段定义当前版本,新项目默认是 0.1.0,edition 字段定义了使用的 Rust 大版本
定义项目依赖
使用 cargo 工具的最大优势就在于,能够对该项目的各种依赖项进行方便、统一和灵活的管理。
在 Cargo.toml 中,主要通过各种依赖段落来描述该项目的各种依赖项:
- 基于 Rust 官方仓库
crates.io,通过版本说明来描述 - 基于项目源代码的 git 仓库地址,通过 URL 来描述
- 基于本地项目的绝对路径或者相对路径,通过类 Unix 模式的路径来描述
这三种形式具体写法如下:
[dependencies]
rand = "0.3"
hammer = { version = "0.5.0"}
color = { git = "https://github.com/bjz/color-rs" }
geometry = { path = "crates/geometry" }
变量绑定与解构
变量绑定
在其它语言中,我们用 var a = "hello world" 的方式给 a 赋值,也就是把等式右边的 "hello world" 字符串赋值给变量 a ,而在 Rust 中,我们这样写: let a = "hello world" ,同时给这个过程起了另一个名字:变量绑定。
为何不用赋值而用绑定呢(其实你也可以称之为赋值,但是绑定的含义更清晰准确)?这里就涉及 Rust 最核心的原则——所有权,简单来讲,任何内存对象都是有主人的,而且一般情况下完全属于它的主人,绑定就是把这个对象绑定给一个变量.
变量可变性
Rust 的变量在默认情况下是不可变的。前文提到,这是 Rust 团队为我们精心设计的语言特性之一,让我们编写的代码更安全,性能也更好。当然你可以通过 mut 关键字让变量变为可变的,让设计更灵活。
如果变量 a 不可变,那么一旦为它绑定值,就不能再修改 a:
fn main() { let x = 5; println!("The value of x is: {}", x); x = 6; println!("The value of x is: {}", x); }
在 Rust 中,可变性很简单,只要在变量名前加一个 mut 即可, 而且这种显式的声明方式还会给后来人传达这样的信息:嗯,这个变量在后面代码部分会发生改变。
fn main() { let mut x = 5; println!("The value of x is: {}", x); x = 6; println!("The value of x is: {}", x); }
如果创建了一个变量却不在任何地方使用它,Rust 通常会给你一个警告,因为这可能会是个 BUG。但是有时创建一个不会被使用的变量是有用的,比如你正在设计原型或刚刚开始一个项目。这时你希望告诉 Rust 不要警告未使用的变量,为此可以用下划线作为变量名的开头:
fn main() { let _x = 5; let y = 10; }
变量解构
let 表达式不仅仅用于变量的绑定,还能进行复杂变量的解构:从一个相对复杂的变量中,匹配出该变量的一部分内容:
fn main() { let (a, mut b): (bool,bool) = (true, false); // a = true,不可变; b = false,可变 println!("a = {:?}, b = {:?}", a, b); b = true; assert_eq!(a, b); }
解构式赋值
在Rust 1.59版本后,我们可以在赋值语句的左式中使用元组、切片和结构体模式了。
struct Struct { e: i32 } fn main() { let (a, b, c, d, e); (a, b) = (1, 2); // _ 代表匹配一个值,但是我们不关心具体的值是什么,因此没有使用一个变量名而是使用了 _ [c, .., d, _] = [1, 2, 3, 4, 5]; Struct { e, .. } = Struct { e: 5 }; assert_eq!([1, 2, 1, 4, 5], [a, b, c, d, e]); }
这种使用方式跟之前的 let 保持了一致性,但是 let 会重新绑定,而这里仅仅是对之前绑定的变量进行再赋值。
变量和常量之间的差异
变量的值不能更改可能让你想起其他另一个很多语言都有的编程概念:常量(constant)。与不可变变量一样,常量也是绑定到一个常量名且不允许更改的值,但是常量和变量之间存在一些差异:
- 常量不允许使用
mut。常量不仅仅默认不可变,而且自始至终不可变,因为常量在编译完成后,已经确定它的值。 - 常量使用
const关键字而不是let关键字来声明,并且值的类型必须标注。
常量可以在任意作用域内声明,包括全局作用域,在声明的作用域内,常量在程序运行的整个过程中都有效。
变量遮蔽(shadowing)
Rust 允许声明相同的变量名,在后面声明的变量会遮蔽掉前面声明的,如下所示:
fn main() { let x = 5; // 在main函数的作用域内对之前的x进行遮蔽 let x = x + 1; { // 在当前的花括号作用域内,对之前的x进行遮蔽 let x = x * 2; println!("The value of x in the inner scope is: {}", x); } println!("The value of x is: {}", x); }
这个程序首先将数值 5 绑定到 x,然后通过重复使用 let x = 来遮蔽之前的 x,并取原来的值加上 1,所以 x 的值变成了 6。第三个 let 语句同样遮蔽前面的 x,取之前的值并乘上 2,得到的 x 最终值为 12。
这和 mut 变量的使用是不同的,第二个 let 生成了完全不同的新变量,两个变量只是恰好拥有同样的名称,涉及一次内存对象的再分配
,而 mut 声明的变量,可以修改同一个内存地址上的值,并不会发生内存对象的再分配,性能要更好.
变量遮蔽的用处在于,如果你在某个作用域内无需再使用之前的变量(在被遮蔽后,无法再访问到之前的同名变量),就可以重复的使用变量名字,而不用绞尽脑汁去想更多的名字
所有权,引用与借用
所有权
所有的程序都必须和计算机内存打交道,如何从内存中申请空间来存放程序的运行内容,如何在不需要的时候释放这些空间,成了重中之重,在计算机语言不断演变过程中,出现了三种流派:
- 垃圾回收机制(GC),在程序运行时不断寻找不再使用的内存,典型代表:Java、Go
- 手动管理内存的分配和释放, 在程序中,通过函数调用的方式来申请和释放内存,典型代表:C++
- 通过所有权来管理内存,编译器在编译时会根据一系列规则进行检查
其中 Rust 选择了第三种,最妙的是,这种检查只发生在编译期,因此对于程序运行期,不会有任何性能上的损失。
一段不安全的代码
先来看看一段来自 C 语言的糟糕代码:
int* foo() {
int a; // 变量a的作用域开始
a = 100;
char *c = "xyz"; // 变量c的作用域开始
return &a;
} // 变量a和c的作用域结束
这段代码虽然可以编译通过,但是其实非常糟糕,变量 a 和 c 都是局部变量,函数结束后将局部变量 a 的地址返回,但局部变量 a 存在栈中,在离开作用域后,a 所申请的栈上内存都会被系统回收,从而造成了 悬空指针(Dangling Pointer) 的问题。这是一个非常典型的内存安全问题,虽然编译可以通过,但是运行的时候会出现错误, 很多编程语言都存在。再来看变量 c,c 的值是常量字符串,存储于常量区,可能这个函数我们只调用了一次,也可能我们不再会使用这个字符串,但 "xyz" 只有当整个程序结束后系统才能回收这片内存。
栈(Stack)与堆(Heap)
栈和堆的核心目标就是为程序在运行时提供可供使用的内存空间。
栈
栈按照顺序存储值并以相反顺序取出值,这也被称作后进先出。增加数据叫做进栈,移出数据则叫做出栈。
因为上述的实现方式,栈中的所有数据都必须占用已知且固定大小的内存空间,假设数据大小是未知的,那么在取出数据时,你将无法取到你想要的数据。
堆
与栈不同,对于大小未知或者可能变化的数据,我们需要将它存储在堆上。
当向堆上放入数据时,需要请求一定大小的内存空间。操作系统在堆的某处找到一块足够大的空位,把它标记为已使用,并返回一个表示该位置地址的指针, 该过程被称为在堆上分配内存,有时简称为 “分配”(allocating)。
接着,该指针会被推入栈中,因为指针的大小是已知且固定的,在后续使用过程中,你将通过栈中的指针,来获取数据在堆上的实际内存位置,进而访问该数据。
性能区别
写入方面:入栈比在堆上分配内存要快,因为入栈时操作系统无需分配新的空间,只需要将新数据放入栈顶即可。相比之下,在堆上分配内存则需要更多的工作,这是因为操作系统必须首先找到一块足够存放数据的内存空间,接着做一些记录为下一次分配做准备。
读取方面:得益于 CPU 高速缓存,使得处理器可以减少对内存的访问,高速缓存和内存的访问速度差异在 10 倍以上!栈数据往往可以直接存储在 CPU 高速缓存中,而堆数据只能存储在内存中。访问堆上的数据比访问栈上的数据慢,因为必须先访问栈再通过栈上的指针来访问内存。
因此,处理器处理和分配在栈上的数据会比在堆上的数据更加高效。
所有权与堆栈
当你的代码调用一个函数时,传递给函数的参数(包括可能指向堆上数据的指针和函数的局部变量)依次被压入栈中,当函数调用结束时,这些值将被从栈中按照相反的顺序依次移除。
因为堆上的数据缺乏组织,因此跟踪这些数据何时分配和释放是非常重要的,否则堆上的数据将产生内存泄漏 —— 这些数据将永远无法被回收。这就是 Rust 所有权系统为我们提供的强大保障。
对于其他很多编程语言,你确实无需理解堆栈的原理,但是在 Rust 中,明白堆栈的原理,对于我们理解所有权的工作原理会有很大的帮助。
所有权原则
理解了堆栈,接下来看一下关于所有权的规则,首先请谨记以下规则:
- Rust 中每一个值都被一个变量所拥有,该变量被称为值的所有者
- 一个值同时只能被一个变量所拥有,或者说一个值只能拥有一个所有者
- 当所有者(变量)离开作用域范围时,这个值将被丢弃(drop)
变量作用域
作用域是一个变量在程序中有效的范围, 假如有这样一个变量:
#![allow(unused)] fn main() { let s = "hello"; }
变量 s 绑定到了一个字符串字面值,该字符串字面值是硬编码到程序代码中的。s 变量从声明的点开始直到当前作用域的结束都是有效的:
#![allow(unused)] fn main() { { // s 在这里无效,它尚未声明 let s = "hello"; // 从此处起,s 是有效的 // 使用 s } // 此作用域已结束,s不再有效 }
简而言之,s 从创建伊始就开始有效,然后有效期持续到它离开作用域为止,可以看出,就作用域来说,Rust 语言跟其他编程语言没有区别。
简单介绍 String 类型
我们已经见过字符串字面值 let s ="hello",s 是被硬编码进程序里的字符串值(类型为 &str )。字符串字面值是很方便的,但是它并不适用于所有场景。原因有二:
- 字符串字面值是不可变的,因为被硬编码到程序代码中
- 并非所有字符串的值都能在编写代码时得知
例如,字符串是需要程序运行时,通过用户动态输入然后存储在内存中的,这种情况,字符串字面值就完全无用武之地。 为此,Rust 为我们提供动态字符串类型: String, 该类型被分配到堆上,因此可以动态伸缩,也就能存储在编译时大小未知的文本。
可以使用下面的方法基于字符串字面量来创建 String 类型:
#![allow(unused)] fn main() { let s = String::from("hello"); }
:: 是一种调用操作符,这里表示调用 String 中的 from 方法,因为 String 存储在堆上是动态的,你可以这样修改它:
#![allow(unused)] fn main() { let mut s = String::from("hello"); s.push_str(", world!"); // push_str() 在字符串后追加字面值 println!("{}", s); // 将打印 `hello, world!` }
变量绑定背后的数据交互
转移所有权
先来看一段代码:
#![allow(unused)] fn main() { let x = 5; let y = x; }
代码背后的逻辑很简单, 将 5 绑定到变量 x;接着拷贝 x 的值赋给 y,最终 x 和 y 都等于 5,因为整数是 Rust 基本数据类型,是固定大小的简单值,因此这两个值都是通过自动拷贝的方式来赋值的,都被存在栈中,完全无需在堆上分配内存。
这种拷贝不消耗性能吗?实际上,这种栈上的数据足够简单,而且拷贝非常非常快,只需要复制一个整数大小(i32,4 个字节)的内存即可,因此在这种情况下,拷贝的速度远比在堆上创建内存来得快的多。实际上,上一章我们讲到的 Rust 基本类型都是通过自动拷贝的方式来赋值的,就像上面代码一样。
然后再来看一段代码:
#![allow(unused)] fn main() { let s1 = String::from("hello"); let s2 = s1; }
对于基本类型(存储在栈上),Rust 会自动拷贝,但是 String 不是基本类型,而且是存储在堆上的,因此不能自动拷贝。
实际上, String 类型是一个复杂类型,由存储在栈中的堆指针、字符串长度、字符串容量共同组成,其中堆指针是最重要的,它指向了真实存储字符串内容的堆内存,至于长度和容量,如果你有 Go 语言的经验,这里就很好理解:容量是堆内存分配空间的大小,长度是目前已经使用的大小。
假定一个值可以拥有两个所有者,会发生什么呢?当变量离开作用域后,Rust 会自动调用 drop 函数并清理变量的堆内存。不过由于两个 String 变量指向了同一位置。这就有了一个问题:当 s1 和 s2 离开作用域,它们都会尝试释放相同的内存。这是一个叫做 二次释放(double free) 的错误,也是之前提到过的内存安全性 BUG 之一。两次释放(相同)内存会导致内存污染,它可能会导致潜在的安全漏洞。
因此,Rust 实际这样解决问题:当 s1 赋予 s2 后,Rust 认为 s1 不再有效,因此也无需在 s1 离开作用域后 drop 任何东西,这就是把所有权从 s1 转移给了 s2,s1 在被赋予 s2 后就马上失效了。
再来看看,在所有权转移后再来使用旧的所有者,会发生什么:
#![allow(unused)] fn main() { let s1 = String::from("hello"); let s2 = s1; println!("{}, world!", s1); }
回头看之前的规则
- Rust 中每一个值都被一个变量所拥有,该变量被称为值的所有者
- 一个值同时只能被一个变量所拥有,或者说一个值只能拥有一个所有者
- 当所有者(变量)离开作用域范围时,这个值将被丢弃(drop)
如果你在其他语言中听说过术语 浅拷贝(shallow copy) 和 深拷贝(deep copy),那么拷贝指针、长度和容量而不拷贝数据听起来就像浅拷贝,但是又因为 Rust 同时使第一个变量 s1 无效了,因此这个操作被称为 移动(move),而不是浅拷贝。上面的例子可以解读为 s1 被移动到了 s2 中。那么具体发生了什么,用一张图简单说明:
这样就解决了我们之前的问题,s1 不再指向任何数据,只有 s2 是有效的,当 s2 离开作用域,它就会释放内存。 相信此刻,你应该明白了,为什么 Rust 称呼 let a = b 为变量绑定了吧?
再来看一段代码:
fn main() { let x: &str = "hello, world"; let y = x; println!("{},{}",x,y); }
这段代码,大家觉得会否报错?如果参考之前的 String 所有权转移的例子,那这段代码也应该报错才是,但是实际上呢?
这段代码和之前的 String 有一个本质上的区别:在 String 的例子中 s1 持有了通过String::from("hello") 创建的值的所有权,而这个例子中,x 只是引用了存储在二进制中的字符串 "hello, world",并没有持有所有权。
因此 let y = x 中,仅仅是对该引用进行了拷贝,此时 y 和 x 都引用了同一个字符串。学习了 "引用与借用" 后,自然而言就会理解。
克隆(深拷贝)
首先,Rust 永远也不会自动创建数据的 “深拷贝”。因此,任何自动的复制都不是深拷贝,可以被认为对运行时性能影响较小。
如果我们确实需要深度复制 String 中堆上的数据,而不仅仅是栈上的数据,可以使用一个叫做 clone 的方法。
#![allow(unused)] fn main() { let s1 = String::from("hello"); let s2 = s1.clone(); println!("s1 = {}, s2 = {}", s1, s2); }
这段代码能够正常运行,因此说明 s2 确实完整的复制了 s1 的数据。
如果代码性能无关紧要,例如初始化程序时,或者在某段时间只会执行一次时,你可以使用 clone 来简化编程。但是对于执行较为频繁的代码(热点路径),使用 clone 会极大的降低程序性能,需要小心使用!
拷贝(浅拷贝)
浅拷贝只发生在栈上,因此性能很高,在日常编程中,浅拷贝无处不在。
再回到之前看过的例子:
#![allow(unused)] fn main() { let x = 5; let y = x; println!("x = {}, y = {}", x, y); }
但这段代码似乎与我们刚刚学到的内容相矛盾:没有调用 clone,不过依然实现了类似深拷贝的效果 —— 没有报所有权的错误。
原因是像整型这样的基本类型在编译时是已知大小的,会被存储在栈上,所以拷贝其实际的值是快速的。这意味着没有理由在创建变量 y 后使 x 无效(x、y 都仍然有效)。换句话说,这里没有深浅拷贝的区别,因此这里调用 clone 并不会与通常的浅拷贝有什么不同,我们可以不用管它(可以理解成在栈上做了深拷贝)。
Rust 有一个叫做 Copy 的特征,可以用在类似整型这样在栈中存储的类型。如果一个类型拥有 Copy 特征,一个旧的变量在被赋值给其他变量后仍然可用。
那么什么类型是可 Copy 的呢?可以查看给定类型的文档来确认,不过作为一个通用的规则: 任何基本类型的组合可以 Copy ,不需要分配内存或某种形式资源的类型是可以 Copy 的。如下是一些 Copy 的类型:
- 所有整数类型,比如
u32。 - 布尔类型,
bool,它的值是true和false。 - 所有浮点数类型,比如
f64。 - 字符类型,
char。 - 元组,当且仅当其包含的类型也都是
Copy的时候。比如,(i32, i32)是Copy的,但(i32, String)就不是。 - 不可变引用
&T,但是注意: 可变引用&mut T是不可以 Copy的
函数传值与返回
将值传递给函数,一样会发生 移动 或者 复制,就跟 let 语句一样,下面的代码展示了所有权、作用域的规则:
fn main() { let s = String::from("hello"); // s 进入作用域 takes_ownership(s); // s 的值移动到函数里 ... // ... 所以到这里不再有效 let x = 5; // x 进入作用域 makes_copy(x); // x 应该移动函数里, // 但 i32 是 Copy 的,所以在后面可继续使用 x } // 这里, x 先移出了作用域,然后是 s。但因为 s 的值已被移走, // 所以不会有特殊操作 fn takes_ownership(some_string: String) { // some_string 进入作用域 println!("{}", some_string); } // 这里,some_string 移出作用域并调用 `drop` 方法。占用的内存被释放 fn makes_copy(some_integer: i32) { // some_integer 进入作用域 println!("{}", some_integer); } // 这里,some_integer 移出作用域。不会有特殊操作
你可以尝试在 takes_ownership 之后,再使用 s,看看如何报错?例如添加一行 println!("在move进函数后继续使用s: {}",s);。
同样的,函数返回值也有所有权,例如:
fn main() { let s1 = gives_ownership(); // gives_ownership 将返回值 // 移给 s1 let s2 = String::from("hello"); // s2 进入作用域 let s3 = takes_and_gives_back(s2); // s2 被移动到 // takes_and_gives_back 中, // 它也将返回值移给 s3 } // 这里, s3 移出作用域并被丢弃。s2 也移出作用域,但已被移走, // 所以什么也不会发生。s1 移出作用域并被丢弃 fn gives_ownership() -> String { // gives_ownership 将返回值移动给 // 调用它的函数 let some_string = String::from("hello"); // some_string 进入作用域. some_string // 返回 some_string 并移出给调用的函数 } // takes_and_gives_back 将传入字符串并返回该值 fn takes_and_gives_back(a_string: String) -> String { // a_string 进入作用域 a_string // 返回 a_string 并移出给调用的函数 }
所有权很强大,避免了内存的不安全性,但是也带来了一个新麻烦: 总是把一个值传来传去来使用它。 传入一个函数,很可能还要从该函数传出去,结果就是语言表达变得非常啰嗦,幸运的是,Rust 提供了新功能解决这个问题。
引用与借用
上节中提到,如果仅仅支持通过转移所有权的方式获取一个值,那会让程序变得复杂。 Rust 能否像其它编程语言一样,使用某个变量的指针或者引用呢?答案是可以。
Rust 通过 借用(Borrowing) 这个概念来达成上述的目的,获取变量的引用,称之为借用(borrowing)。
引用与解引用
常规引用是一个指针类型,指向了对象存储的内存地址。在下面代码中,我们创建一个 i32 值的引用 y,然后使用解引用运算符*来解出 y 所使用的值:
fn main() { let x = 5; let y = &x; assert_eq!(5, x); assert_eq!(5, *y); }
变量 x 存放了一个 i32 值 5。y 是 x 的一个引用。可以断言 x 等于 5。然而,如果希望对 y 的值做出断言,必须使用 *y 来解出引用所指向的值(也就是解引用)。一旦解引用了 y,就可以访问 y 所指向的整型值并可以与 5 做比较。
相反如果尝试编写 assert_eq!(5, y);,则会得到如下编译错误:
error[E0277]: can't compare `{integer}` with `&{integer}`
--> src/main.rs:6:5
|
6 | assert_eq!(5, y);
| ^^^^^^^^^^^^^^^^^ no implementation for `{integer} == &{integer}` // 无法比较整数类型和引用类型
|
= help: the trait `std::cmp::PartialEq<&{integer}>` is not implemented for
`{integer}`
不允许比较整数与引用,因为它们是不同的类型。必须使用解引用运算符解出引用所指向的值。
不可变引用
下面的代码,我们用 s1 的引用作为参数传递给 calculate_length 函数,而不是把 s1 的所有权转移给该函数:
fn main() { let s1 = String::from("hello"); let len = calculate_length(&s1); println!("The length of '{}' is {}.", s1, len); } fn calculate_length(s: &String) -> usize { s.len() }
能注意到两点:
- 无需像上章一样:先通过函数参数传入所有权,然后再通过函数返回来传出所有权,代码更加简洁
calculate_length的参数s类型从String变为&String
这里,& 符号即是引用,它们允许你使用值,但是不获取所有权,如图所示:

通过 &s1 语法,我们创建了一个指向 s1 的引用,但是并不拥有它。因为并不拥有这个值,当引用离开作用域后,其指向的值也不会被丢弃。
同理,函数 calculate_length 使用 & 来表明参数 s 的类型是一个引用:
#![allow(unused)] fn main() { fn calculate_length(s: &String) -> usize { // s 是对 String 的引用 s.len() } // 这里,s 离开了作用域。但因为它并不拥有引用值的所有权, // 所以什么也不会发生 }
因此光借用满足不了我们,如果尝试修改借用的变量呢?
fn main() { let s = String::from("hello"); change(&s); } fn change(some_string: &String) { some_string.push_str(", world"); }
很不幸,修改错了:
error[E0596]: cannot borrow `*some_string` as mutable, as it is behind a `&` reference
--> src/main.rs:8:5
|
7 | fn change(some_string: &String) {
| ------- help: consider changing this to be a mutable reference: `&mut String`
------- 帮助:考虑将该参数类型修改为可变的引用: `&mut String`
8 | some_string.push_str(", world");
| ^^^^^^^^^^^ `some_string` is a `&` reference, so the data it refers to cannot be borrowed as mutable
`some_string`是一个`&`类型的引用,因此它指向的数据无法进行修改
正如变量默认不可变一样,引用指向的值默认也是不可变的,没事,来一起看看如何解决这个问题。
可变引用
只需要一个小调整,即可修复上面代码的错误:
fn main() { let mut s = String::from("hello"); change(&mut s); } fn change(some_string: &mut String) { some_string.push_str(", world"); }
首先,声明 s 是可变类型,其次创建一个可变的引用 &mut s 和接受可变引用参数 some_string: &mut String 的函数。
可变引用同时只能存在一个
不过可变引用并不是随心所欲、想用就用的,它有一个很大的限制: 同一作用域,特定数据只能有一个可变引用:
#![allow(unused)] fn main() { let mut s = String::from("hello"); let r1 = &mut s; let r2 = &mut s; println!("{}, {}", r1, r2); }
以上代码会报错:
error[E0499]: cannot borrow `s` as mutable more than once at a time 同一时间无法对 `s` 进行两次可变借用
--> src/main.rs:5:14
|
4 | let r1 = &mut s;
| ------ first mutable borrow occurs here 首个可变引用在这里借用
5 | let r2 = &mut s;
| ^^^^^^ second mutable borrow occurs here 第二个可变引用在这里借用
6 |
7 | println!("{}, {}", r1, r2);
| -- first borrow later used here 第一个借用在这里使用
这段代码出错的原因在于,第一个可变借用 r1 必须要持续到最后一次使用的位置 println!,在 r1 创建和最后一次使用之间,我们又尝试创建第二个可变借用 r2。
对于新手来说,这个特性绝对是一大拦路虎,也是新人们谈之色变的编译器 borrow checker 特性之一,不过各行各业都一样,限制往往是出于安全的考虑,Rust 也一样。
这种限制的好处就是使 Rust 在编译期就避免数据竞争,数据竞争可由以下行为造成:
- 两个或更多的指针同时访问同一数据
- 至少有一个指针被用来写入数据
- 没有同步数据访问的机制
数据竞争会导致未定义行为,这种行为很可能超出我们的预期,难以在运行时追踪,并且难以诊断和修复。而 Rust 避免了这种情况的发生,因为它甚至不会编译存在数据竞争的代码!
很多时候,大括号可以帮我们解决一些编译不通过的问题,通过手动限制变量的作用域:
#![allow(unused)] fn main() { let mut s = String::from("hello"); { let r1 = &mut s; } // r1 在这里离开了作用域,所以我们完全可以创建一个新的引用 let r2 = &mut s; }
可变引用与不可变引用不能同时存在
下面的代码会导致一个错误:
#![allow(unused)] fn main() { let mut s = String::from("hello"); let r1 = &s; // 没问题 let r2 = &s; // 没问题 let r3 = &mut s; // 大问题 println!("{}, {}, and {}", r1, r2, r3); }
错误如下:
error[E0502]: cannot borrow `s` as mutable because it is also borrowed as immutable
// 无法借用可变 `s` 因为它已经被借用了不可变
--> src/main.rs:6:14
|
4 | let r1 = &s; // 没问题
| -- immutable borrow occurs here 不可变借用发生在这里
5 | let r2 = &s; // 没问题
6 | let r3 = &mut s; // 大问题
| ^^^^^^ mutable borrow occurs here 可变借用发生在这里
7 |
8 | println!("{}, {}, and {}", r1, r2, r3);
| -- immutable borrow later used here 不可变借用在这里使用
注意,引用的作用域
s从创建开始,一直持续到它最后一次使用的地方,这个跟变量的作用域有所不同,变量的作用域从创建持续到某一个花括号}
NLL
对于这种编译器优化行为,Rust 专门起了一个名字 —— Non-Lexical Lifetimes(NLL),专门用于找到某个引用在作用域(})结束前就不再被使用的代码位置。
虽然这种借用错误有的时候会让我们很郁闷,但是你只要想想这是 Rust 提前帮你发现了潜在的 BUG,其实就开心了,虽然减慢了开发速度,但是从长期来看,大幅减少了后续开发和运维成本。
悬垂引用(Dangling References)
悬垂引用也叫做悬垂指针,意思为指针指向某个值后,这个值被释放掉了,而指针仍然存在,其指向的内存可能不存在任何值或已被其它变量重新使用。在 Rust 中编译器可以确保引用永远也不会变成悬垂状态:当你获取数据的引用后,编译器可以确保数据不会在引用结束前被释放,要想释放数据,必须先停止其引用的使用。
让我们尝试创建一个悬垂引用,Rust 会抛出一个编译时错误:
fn main() { let reference_to_nothing = dangle(); } fn dangle() -> &String { let s = String::from("hello"); &s }
这里是错误:
error[E0106]: missing lifetime specifier
--> src/main.rs:5:16
|
5 | fn dangle() -> &String {
| ^ expected named lifetime parameter
|
= help: this function's return type contains a borrowed value, but there is no value for it to be borrowed from
help: consider using the `'static` lifetime
|
5 | fn dangle() -> &'static String {
| ~~~~~~~~
仔细看看 dangle 代码的每一步到底发生了什么:
#![allow(unused)] fn main() { fn dangle() -> &String { // dangle 返回一个字符串的引用 let s = String::from("hello"); // s 是一个新字符串 &s // 返回字符串 s 的引用 } // 这里 s 离开作用域并被丢弃。其内存被释放。 // 危险! }
因为 s 是在 dangle 函数内创建的,当 dangle 的代码执行完毕后,s 将被释放, 但是此时我们又尝试去返回它的引用。这意味着这个引用会指向一个无效的 String,这可不对!
其中一个很好的解决方法是直接返回 String:
#![allow(unused)] fn main() { fn no_dangle() -> String { let s = String::from("hello"); s } }
这样就没有任何错误了,最终 String 的 所有权被转移给外面的调用者。
借用规则总结
总的来说,借用规则如下:
- 同一时刻,你只能拥有要么一个可变引用, 要么任意多个不可变引用
- 引用必须总是有效的
复合类型
顾名思义,复合类型是由其它类型组合而成的,最典型的就是结构体 struct 和枚举 enum。
来看一段代码,它使用我们之前学过的内容来构建文件操作:
#![allow(unused_variables)] type File = String; fn open(f: &mut File) -> bool { true } fn close(f: &mut File) -> bool { true } #[allow(dead_code)] fn read(f: &mut File, save_to: &mut Vec<u8>) -> ! { unimplemented!() } fn main() { let mut f1 = File::from("f1.txt"); open(&mut f1); //read(&mut f1, &mut vec![]); close(&mut f1); }
接下来我们的学习非常类似原型设计:有的方法只提供 API 接口,但是不提供具体实现。此外,有的变量在声明之后并未使用,因此在这个阶段我们需要排除一些编译器噪音(Rust 在编译的时候会扫描代码,变量声明后未使用会以 warning 警告的形式进行提示),引入 #![allow(unused_variables)] 属性标记,该标记会告诉编译器忽略未使用的变量,不要抛出 warning 警告,具体的常见编译器属性你可以在这里查阅:编译器属性标记。
read 函数也非常有趣,它返回一个 ! 类型,这个表明该函数是一个发散函数,不会返回任何值,包括 ()。unimplemented!() 告诉编译器该函数尚未实现,unimplemented!() 标记通常意味着我们期望快速完成主要代码,回头再通过搜索这些标记来完成次要代码,类似的标记还有 todo!(),当代码执行到这种未实现的地方时,程序会直接报错。你可以反注释 read(&mut f1, &mut vec![]); 这行,然后再观察下结果。
同时,从代码设计角度来看,关于文件操作的类型和函数应该组织在一起,散落得到处都是,是难以管理和使用的。而且通过 open(&mut f1) 进行调用,也远没有使用 f1.open() 来调用好,这就体现出了只使用基本类型的局限性:无法从更高的抽象层次去简化代码。
接下来,我们将引入一个高级数据结构 —— 结构体 struct,来看看复合类型是怎样更好的解决这类问题。 开始之前,先来看看 Rust 的重点也是难点:字符串 String 和 &str。
字符串
首先来看段很简单的代码:
fn main() { let my_name = "Pascal"; greet(my_name); } fn greet(name: String) { println!("Hello, {}!", name); }
greet 函数接受一个字符串类型的 name 参数,然后打印到终端控制台中,非常好理解,你们猜猜,这段代码能否通过编译?
error[E0308]: mismatched types
--> src/main.rs:3:11
|
3 | greet(my_name);
| ^^^^^^^
| |
| expected struct `std::string::String`, found `&str`
| help: try using a conversion method: `my_name.to_string()`
error: aborting due to previous error
Bingo,果然报错了,编译器提示 greet 函数需要一个 String 类型的字符串,却传入了一个 &str 类型的字符串.
在讲解字符串之前,先来看看什么是切片?
切片(slice)
它允许你引用集合中部分连续的元素序列,而不是引用整个集合。
对于字符串而言,切片就是对 String 类型中某一部分的引用,它看起来像这样:
#![allow(unused)] fn main() { let s = String::from("hello world"); let hello = &s[0..5]; let world = &s[6..11]; }
hello 没有引用整个 String s,而是引用了 s 的一部分内容,通过 [0..5] 的方式来指定。
这就是创建切片的语法,使用方括号包括的一个序列:[开始索引..终止索引],其中开始索引是切片中第一个元素的索引位置,而终止索引是最后一个元素后面的索引位置,也就是这是一个 右半开区间。在切片数据结构内部会保存开始的位置和切片的长度,其中长度是通过 终止索引 - 开始索引 的方式计算得来的。
对于 let world = &s[6..11]; 来说,world 是一个切片,该切片的指针指向 s 的第 7 个字节(索引从 0 开始, 6 是第 7 个字节),且该切片的长度是 5 个字节。
这两个是等效的:
#![allow(unused)] fn main() { let s = String::from("hello"); let slice = &s[0..2]; let slice = &s[..2]; }
同样的,如果你的切片想要包含 String 的最后一个字节,则可以这样使用:
#![allow(unused)] fn main() { let s = String::from("hello"); let len = s.len(); let slice = &s[4..len]; let slice = &s[4..]; }
你也可以截取完整的 String 切片:
#![allow(unused)] fn main() { let s = String::from("hello"); let len = s.len(); let slice = &s[0..len]; let slice = &s[..]; }
在对字符串使用切片语法时需要格外小心,切片的索引必须落在字符之间的边界位置,也就是 UTF-8 字符的边界,例如中文在 UTF-8 中占用三个字节,下面的代码就会崩溃:
#![allow(unused)] fn main() { let s = "中国人"; let a = &s[0..2]; println!("{}",a); }因为我们只取
s字符串的前两个字节,但是本例中每个汉字占用三个字节,因此没有落在边界处,也就是连中字都取不完整,此时程序会直接崩溃退出,如果改成&s[0..3],则可以正常通过编译。 因此,当你需要对字符串做切片索引操作时,需要格外小心这一点
字符串切片的类型标识是 &str,因此我们可以这样声明一个函数,输入 String 类型,返回它的切片: fn first_word(s: &String) -> &str 。
有了切片就可以写出这样的代码:
fn main() { let mut s = String::from("hello world"); let word = first_word(&s); s.clear(); // error! println!("the first word is: {}", word); } fn first_word(s: &String) -> &str { &s[..1] }
编译器报错如下:
error[E0502]: cannot borrow `s` as mutable because it is also borrowed as immutable
--> src/main.rs:18:5
|
16 | let word = first_word(&s);
| -- immutable borrow occurs here
17 |
18 | s.clear(); // error!
| ^^^^^^^^^ mutable borrow occurs here
19 |
20 | println!("the first word is: {}", word);
| ---- immutable borrow later used here
回忆一下借用的规则:当我们已经有了可变借用时,就无法再拥有不可变的借用。因为 clear 需要清空改变 String,因此它需要一个可变借用(利用 VSCode 可以看到该方法的声明是 pub fn clear(&mut self) ,参数是对自身的可变借用 );而之后的 println! 又使用了不可变借用,也就是在 s.clear() 处可变借用与不可变借用试图同时生效,因此编译无法通过。
从上述代码可以看出,Rust 不仅让我们的 API 更加容易使用,而且也在编译期就消除了大量错误!
其它切片
因为切片是对集合的部分引用,因此不仅仅字符串有切片,其它集合类型也有,例如数组:
#![allow(unused)] fn main() { let a = [1, 2, 3, 4, 5]; let slice = &a[1..3]; assert_eq!(slice, &[2, 3]); }
该数组切片的类型是 &[i32],数组切片和字符串切片的工作方式是一样的.
字符串字面量是切片
之前提到过字符串字面量,但是没有提到它的类型:
#![allow(unused)] fn main() { let s = "Hello, world!"; }
实际上,s 的类型是 &str,因此你也可以这样声明:
#![allow(unused)] fn main() { let s: &str = "Hello, world!"; }
该切片指向了程序可执行文件中的某个点,这也是为什么字符串字面量是不可变的,因为 &str 是一个不可变引用。
了解完切片,可以进入本节的正题了。
什么是字符串?
顾名思义,字符串是由字符组成的连续集合,但是在上一节中我们提到过,Rust 中的字符是 Unicode 类型,因此每个字符占据 4 个字节内存空间,但是在字符串中不一样,字符串是 UTF-8 编码,也就是字符串中的字符所占的字节数是变化的(1 - 4),这样有助于大幅降低字符串所占用的内存空间。
Rust 在语言级别,只有一种字符串类型: str,它通常是以引用类型出现 &str,也就是上文提到的字符串切片。虽然语言级别只有上述的 str 类型,但是在标准库里,还有多种不同用途的字符串类型,其中使用最广的即是 String 类型。
str 类型是硬编码进可执行文件,也无法被修改,但是 String 则是一个可增长、可改变且具有所有权的 UTF-8 编码字符串,当 Rust 用户提到字符串时,往往指的就是 String 类型和 &str 字符串切片类型,这两个类型都是 UTF-8 编码。
除了 String 类型的字符串,Rust 的标准库还提供了其他类型的字符串,例如 OsString, OsStr, CsString 和 CsStr 等,注意到这些名字都以 String 或者 Str 结尾了吗?它们分别对应的是具有所有权和被借用的变量。
String 与 &str 的转换
在之前的代码中,已经见到好几种从 &str 类型生成 String 类型的操作:
String::from("hello,world")"hello,world".to_string()
那么如何将 String 类型转为 &str 类型呢?答案很简单,取引用即可:
fn main() { let s = String::from("hello,world!"); say_hello(&s); say_hello(&s[..]); say_hello(s.as_str()); } fn say_hello(s: &str) { println!("{}",s); }
实际上这种灵活用法是因为 deref 隐式强制转换,具体我们会在Deref特征进行详细讲解。
字符串索引
在其它语言中,使用索引的方式访问字符串的某个字符或者子串是很正常的行为,但是在 Rust 中就会报错:
#![allow(unused)] fn main() { let s1 = String::from("hello"); let h = s1[0]; }
该代码会产生如下错误:
3 | let h = s1[0];
| ^^^^^ `String` cannot be indexed by `{integer}`
|
= help: the trait `Index<{integer}>` is not implemented for `String`
深入字符串内部
字符串的底层的数据存储格式实际上是[ u8 ],一个字节数组。对于 let hello = String::from("Hola"); 这行代码来说,Hola 的长度是 4 个字节,因为 "Hola" 中的每个字母在 UTF-8 编码中仅占用 1 个字节,但是对于下面的代码呢?
#![allow(unused)] fn main() { let hello = String::from("中国人"); }
如果问你该字符串多长,你可能会说 3,但是实际上是 9 个字节的长度,因为大部分常用汉字在 UTF-8 中的长度是 3 个字节,因此这种情况下对 hello 进行索引,访问 &hello[0] 没有任何意义,因为你取不到 中 这个字符,而是取到了这个字符三个字节中的第一个字节,这是一个非常奇怪而且难以理解的返回值。
字符串的不同表现形式
现在看一下用梵文写的字符串 “नमस्ते”, 它底层的字节数组如下形式:
#![allow(unused)] fn main() { [224, 164, 168, 224, 164, 174, 224, 164, 184, 224, 165, 141, 224, 164, 164, 224, 165, 135] }
长度是 18 个字节,这也是计算机最终存储该字符串的形式。如果从字符的形式去看,则是:
#![allow(unused)] fn main() { ['न', 'म', 'स', '्', 'त', 'े'] }
但是这种形式下,第四和六两个字母根本就不存在,没有任何意义,接着再从字母串的形式去看:
#![allow(unused)] fn main() { ["न", "म", "स्", "ते"] }
所以,可以看出来 Rust 提供了不同的字符串展现方式,这样程序可以挑选自己想要的方式去使用,而无需去管字符串从人类语言角度看长什么样。
还有一个原因导致了 Rust 不允许去索引字符串:因为索引操作,我们总是期望它的性能表现是 O(1),然而对于 String 类型来说,无法保证这一点,因为 Rust 可能需要从 0 开始去遍历字符串来定位合法的字符。
字符串切片
前文提到过,字符串切片是非常危险的操作,因为切片的索引是通过字节来进行,但是字符串又是 UTF-8 编码,因此你无法保证索引的字节刚好落在字符的边界上,例如:
#![allow(unused)] fn main() { let hello = "中国人"; let s = &hello[0..2]; }
运行上面的程序,会直接造成崩溃:
thread 'main' panicked at 'byte index 2 is not a char boundary; it is inside '中' (bytes 0..3) of `中国人`', src/main.rs:4:14
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
这里提示的很清楚,我们索引的字节落在了 中 字符的内部,这种返回没有任何意义。
因此在通过索引区间来访问字符串时,需要格外的小心,一不注意,就会导致你程序的崩溃!
操作字符串
由于 String 是可变字符串,下面介绍 Rust 字符串的修改,添加,删除等常用方法:
追加 (Push)
在字符串尾部可以使用 push() 方法追加字符 char,也可以使用 push_str() 方法追加字符串字面量。这两个方法都是在原有的字符串上追加,并不会返回新的字符串。由于字符串追加操作要修改原来的字符串,则该字符串必须是可变的,即字符串变量必须由 mut 关键字修饰。
示例代码如下:
fn main() { let mut s = String::from("Hello "); s.push_str("rust"); println!("追加字符串 push_str() -> {}", s); s.push('!'); println!("追加字符 push() -> {}", s); }
代码运行结果:
追加字符串 push_str() -> Hello rust
追加字符 push() -> Hello rust!
插入 (Insert)
可以使用 insert() 方法插入单个字符 char,也可以使用 insert_str() 方法插入字符串字面量,与 push() 方法不同,这俩方法需要传入两个参数,第一个参数是字符(串)插入位置的索引,第二个参数是要插入的字符(串),索引从 0 开始计数,如果越界则会发生错误。由于字符串插入操作要修改原来的字符串,则该字符串必须是可变的,即字符串变量必须由 mut 关键字修饰。
示例代码如下:
fn main() { let mut s = String::from("Hello rust!"); s.insert(5, ','); println!("插入字符 insert() -> {}", s); s.insert_str(6, " I like"); println!("插入字符串 insert_str() -> {}", s); }
代码运行结果:
插入字符 insert() -> Hello, rust!
插入字符串 insert_str() -> Hello, I like rust!
替换 (Replace)
如果想要把字符串中的某个字符串替换成其它的字符串,那可以使用 replace() 方法。与替换有关的方法有三个。
1、replace
该方法可适用于 String 和 &str 类型。replace() 方法接收两个参数,第一个参数是要被替换的字符串,第二个参数是新的字符串。该方法会替换所有匹配到的字符串。该方法是返回一个新的字符串,而不是操作原来的字符串。
示例代码如下:
fn main() { let string_replace = String::from("I like rust. Learning rust is my favorite!"); let new_string_replace = string_replace.replace("rust", "RUST"); dbg!(new_string_replace); }
代码运行结果:
new_string_replace = "I like RUST. Learning RUST is my favorite!"
2、replacen
该方法可适用于 String 和 &str 类型。replacen() 方法接收三个参数,前两个参数与 replace() 方法一样,第三个参数则表示替换的个数。该方法是返回一个新的字符串,而不是操作原来的字符串。
示例代码如下:
fn main() { let string_replace = "I like rust. Learning rust is my favorite!"; let new_string_replacen = string_replace.replacen("rust", "RUST", 1); dbg!(new_string_replacen); }
代码运行结果:
new_string_replacen = "I like RUST. Learning rust is my favorite!"
3、replace_range
该方法仅适用于 String 类型。replace_range 接收两个参数,第一个参数是要替换字符串的范围(Range),第二个参数是新的字符串。该方法是直接操作原来的字符串,不会返回新的字符串。该方法需要使用 mut 关键字修饰。
示例代码如下:
fn main() { let mut string_replace_range = String::from("I like rust!"); string_replace_range.replace_range(7..8, "R"); dbg!(string_replace_range); }
代码运行结果:
string_replace_range = "I like Rust!"
删除 (Delete)
与字符串删除相关的方法有 4 个,他们分别是 pop(),remove(),truncate(),clear()。这四个方法仅适用于 String 类型。
1、 pop —— 删除并返回字符串的最后一个字符
该方法是直接操作原来的字符串。但是存在返回值,其返回值是一个 Option 类型,如果字符串为空,则返回 None。
示例代码如下:
fn main() { let mut string_pop = String::from("rust pop 中文!"); let p1 = string_pop.pop(); let p2 = string_pop.pop(); dbg!(p1); dbg!(p2); dbg!(string_pop); }
代码运行结果:
p1 = Some(
'!',
)
p2 = Some(
'文',
)
string_pop = "rust pop 中"
2、 remove —— 删除并返回字符串中指定位置的字符
该方法是直接操作原来的字符串。但是存在返回值,其返回值是删除位置的字符串,只接收一个参数,表示该字符起始索引位置。remove() 方法是按照字节来处理字符串的,如果参数所给的位置不是合法的字符边界,则会发生错误。
示例代码如下:
fn main() { let mut string_remove = String::from("测试remove方法"); println!( "string_remove 占 {} 个字节", std::mem::size_of_val(string_remove.as_str()) ); // 删除第一个汉字 string_remove.remove(0); // 下面代码会发生错误 // string_remove.remove(1); // 直接删除第二个汉字 // string_remove.remove(3); dbg!(string_remove); }
代码运行结果:
string_remove 占 18 个字节
string_remove = "试remove方法"
3、truncate —— 删除字符串中从指定位置开始到结尾的全部字符
该方法是直接操作原来的字符串。无返回值。该方法 truncate() 方法是按照字节来处理字符串的,如果参数所给的位置不是合法的字符边界,则会发生错误。
示例代码如下:
fn main() { let mut string_truncate = String::from("测试truncate"); string_truncate.truncate(3); dbg!(string_truncate); }
代码运行结果:
string_truncate = "测"
4、clear —— 清空字符串
该方法是直接操作原来的字符串。调用后,删除字符串中的所有字符,相当于 truncate() 方法参数为 0 的时候。
示例代码如下:
fn main() { let mut string_clear = String::from("string clear"); string_clear.clear(); dbg!(string_clear); }
代码运行结果:
string_clear = ""
连接 (Concatenate)
1、使用 + 或者 += 连接字符串
使用 + 或者 += 连接字符串,要求右边的参数必须为字符串的切片引用(Slice)类型。其实当调用 + 的操作符时,相当于调用了 std::string 标准库中的 add() 方法,这里 add() 方法的第二个参数是一个引用的类型。因此我们在使用 +, 必须传递切片引用类型。不能直接传递 String 类型。+ 和 += 都是返回一个新的字符串。所以变量声明可以不需要 mut 关键字修饰。
示例代码如下:
fn main() { let string_append = String::from("hello "); let string_rust = String::from("rust"); // &string_rust会自动解引用为&str let result = string_append + &string_rust; let mut result = result + "!"; result += "!!!"; println!("连接字符串 + -> {}", result); }
代码运行结果:
连接字符串 + -> hello rust!!!!
add() 方法的定义:
#![allow(unused)] fn main() { fn add(self, s: &str) -> String }
因为该方法涉及到更复杂的特征功能,因此我们这里简单说明下:
fn main() { let s1 = String::from("hello,"); let s2 = String::from("world!"); // 在下句中,s1的所有权被转移走了,因此后面不能再使用s1 let s3 = s1 + &s2; assert_eq!(s3,"hello,world!"); // 下面的语句如果去掉注释,就会报错 // println!("{}",s1); }
self 是 String 类型的字符串 s1,该函数说明,只能将 &str 类型的字符串切片添加到 String 类型的 s1 上,然后返回一个新的 String 类型,所以 let s3 = s1 + &s2; 就很好解释了,将 String 类型的 s1 与 &str 类型的 s2 进行相加,最终得到 String 类型的 s3。
由此可推,以下代码也是合法的:
#![allow(unused)] fn main() { let s1 = String::from("tic"); let s2 = String::from("tac"); let s3 = String::from("toe"); // String = String + &str + &str + &str + &str let s = s1 + "-" + &s2 + "-" + &s3; }
String + &str返回一个 String,然后再继续跟一个 &str 进行 + 操作,返回一个 String 类型,不断循环,最终生成一个 s,也是 String 类型。
s1 这个变量通过调用 add() 方法后,所有权被转移到 add() 方法里面, add() 方法调用后就被释放了,同时 s1 也被释放了。再使用 s1 就会发生错误。这里涉及到所有权转移(Move)的相关知识。
2、使用 format! 连接字符串
format! 这种方式适用于 String 和 &str 。format! 的用法与 print! 的用法类似, 示例代码如下:
fn main() { let s1 = "hello"; let s2 = String::from("rust"); let s = format!("{} {}!", s1, s2); println!("{}", s); }
代码运行结果:
hello rust!
字符串转义
我们可以通过转义的方式 \ 输出 ASCII 和 Unicode 字符。
fn main() { // 通过 \ + 字符的十六进制表示,转义输出一个字符 let byte_escape = "I'm writing \x52\x75\x73\x74!"; println!("What are you doing\x3F (\\x3F means ?) {}", byte_escape); // \u 可以输出一个 unicode 字符 let unicode_codepoint = "\u{211D}"; let character_name = "\"DOUBLE-STRUCK CAPITAL R\""; println!( "Unicode character {} (U+211D) is called {}", unicode_codepoint, character_name ); // 换行了也会保持之前的字符串格式 let long_string = "String literals can span multiple lines. The linebreak and indentation here ->\ <- can be escaped too!"; println!("{}", long_string); }
当然,在某些情况下,可能你会希望保持字符串的原样,不要转义:
fn main() { println!("{}", "hello \\x52\\x75\\x73\\x74"); let raw_str = r"Escapes don't work here: \x3F \u{211D}"; println!("{}", raw_str); // 如果字符串包含双引号,可以在开头和结尾加 # let quotes = r#"And then I said: "There is no escape!""#; println!("{}", quotes); // 如果还是有歧义,可以继续增加,没有限制 let longer_delimiter = r###"A string with "# in it. And even "##!"###; println!("{}", longer_delimiter); }
操作 UTF-8 字符串
前文提到了几种使用 UTF-8 字符串的方式,下面来一一说明。
字符
如果你想要以 Unicode 字符的方式遍历字符串,最好的办法是使用 chars 方法,例如:
#![allow(unused)] fn main() { for c in "中国人".chars() { println!("{}", c); } }
输出如下
中
国
人
字节
这种方式是返回字符串的底层字节数组表现形式:
#![allow(unused)] fn main() { for b in "中国人".bytes() { println!("{}", b); } }
输出如下:
228
184
173
229
155
189
228
186
186
获取子串
想要准确的从 UTF-8 字符串中获取子串是较为复杂的事情,例如想要从 holla中国人नमस्ते 这种变长的字符串中取出某一个子串,使用标准库你是做不到的。
你需要在 crates.io 上搜索 utf8 来寻找想要的功能。
可以考虑尝试下这个库:utf8_slice。
字符串深度剖析
那么问题来了,为啥 String 可变,而字符串字面值 str 却不可以?
就字符串字面值来说,我们在编译时就知道其内容,最终字面值文本被直接硬编码进可执行文件中,这使得字符串字面值快速且高效,这主要得益于字符串字面值的不可变性。不幸的是,我们不能为了获得这种性能,而把每一个在编译时大小未知的文本都放进内存中(你也做不到!),因为有的字符串是在程序运行得过程中动态生成的。
对于 String 类型,为了支持一个可变、可增长的文本片段,需要在堆上分配一块在编译时未知大小的内存来存放内容,这些都是在程序运行时完成的:
- 首先向操作系统请求内存来存放
String对象 - 在使用完成后,将内存释放,归还给操作系统
其中第一部分由 String::from 完成,它创建了一个全新的 String。
重点来了,到了第二部分,就是百家齐放的环节,在有垃圾回收 GC 的语言中,GC 来负责标记并清除这些不再使用的内存对象,这个过程都是自动完成,无需开发者关心,非常简单好用;但是在无 GC 的语言中,需要开发者手动去释放这些内存对象,就像创建对象需要通过编写代码来完成一样,未能正确释放对象造成的后果简直不可估量。
对于 Rust 而言,安全和性能是写到骨子里的核心特性,如果使用 GC,那么会牺牲性能;如果使用手动管理内存,那么会牺牲安全,这该怎么办?为此,Rust 的开发者想出了一个无比惊艳的办法:变量在离开作用域后,就自动释放其占用的内存:
#![allow(unused)] fn main() { { let s = String::from("hello"); // 从此处起,s 是有效的 // 使用 s } // 此作用域已结束, // s 不再有效,内存被释放 }
与其它系统编程语言的 free 函数相同,Rust 也提供了一个释放内存的函数: drop,但是不同的是,其它语言要手动调用 free 来释放每一个变量占用的内存,而 Rust 则在变量离开作用域时,自动调用 drop 函数: 上面代码中,Rust 在结尾的 } 处自动调用 drop。
其实,在 C++ 中,也有这种概念: Resource Acquisition Is Initialization (RAII)。如果你使用过 RAII 模式的话应该对 Rust 的
drop函数并不陌生
元组
元组是由多种类型组合到一起形成的,因此它是复合类型,元组的长度是固定的,元组中元素的顺序也是固定的。
可以通过以下语法创建一个元组:
fn main() { let tup: (i32, f64, u8) = (500, 6.4, 1); }
变量 tup 被绑定了一个元组值 (500, 6.4, 1),该元组的类型是 (i32, f64, u8).
可以使用模式匹配或者 . 操作符来获取元组中的值。
用模式匹配解构元组
fn main() { let tup = (500, 6.4, 1); let (x, y, z) = tup; println!("The value of y is: {}", y); }
上述代码首先创建一个元组,然后将其绑定到 tup 上,接着使用 let (x, y, z) = tup; 来完成一次模式匹配,因为元组是 (n1, n2, n3) 形式的,因此我们用一模一样的 (x, y, z) 形式来进行匹配,元组中对应的值会绑定到变量 x, y, z上。这就是解构:用同样的形式把一个复杂对象中的值匹配出来。
用 . 来访问元组
模式匹配可以让我们一次性把元组中的值全部或者部分获取出来,如果只想要访问某个特定元素,那模式匹配就略显繁琐,对此,Rust 提供了 . 的访问方式:
fn main() { let x: (i32, f64, u8) = (500, 6.4, 1); let five_hundred = x.0; let six_point_four = x.1; let one = x.2; }
和其它语言的数组、字符串一样,元组的索引从 0 开始。
元组的使用示例
元组在函数返回值场景很常用,例如下面的代码,可以使用元组返回多个值:
fn main() { let s1 = String::from("hello"); let (s2, len) = calculate_length(s1); println!("The length of '{}' is {}.", s2, len); } fn calculate_length(s: String) -> (String, usize) { let length = s.len(); // len() 返回字符串的长度 (s, length) }
calculate_length 函数接收 s1 字符串的所有权,然后计算字符串的长度,接着把字符串所有权和字符串长度再返回给 s2 和 len 变量。
在其他语言中,可以用结构体来声明一个三维空间中的点,例如 Point(10, 20, 30),虽然使用 Rust 元组也可以做到:(10, 20, 30),但是这样写有个非常重大的缺陷:
不具备任何清晰的含义,在下一章节中,会提到一种与元组类似的结构体,元组结构体,可以解决这个问题。
结构体
结构体跟元组有些相像:都是由多种类型组合而成。但是与元组不同的是,结构体可以为内部的每个字段起一个富有含义的名称。
结构体语法
定义结构体
一个结构体由几部分组成:
- 通过关键字
struct定义 - 一个清晰明确的结构体
名称 - 几个有名字的结构体
字段
例如, 以下结构体定义了某网站的用户:
#![allow(unused)] fn main() { struct User { active: bool, username: String, email: String, sign_in_count: u64, } }
该结构体名称是 User,拥有 4 个字段,且每个字段都有对应的字段名及类型声明,例如 username 代表了用户名,是一个可变的 String 类型。
创建结构体实例
为了使用上述结构体,我们需要创建 User 结构体的实例:
#![allow(unused)] fn main() { let user1 = User { email: String::from("someone@example.com"), username: String::from("someusername123"), active: true, sign_in_count: 1, }; }
有几点值得注意:
- 初始化实例时,每个字段都需要进行初始化
- 初始化时的字段顺序不需要和结构体定义时的顺序一致
访问结构体字段
通过 . 操作符即可访问结构体实例内部的字段值,也可以修改它们:
#![allow(unused)] fn main() { let mut user1 = User { email: String::from("someone@example.com"), username: String::from("someusername123"), active: true, sign_in_count: 1, }; user1.email = String::from("anotheremail@example.com"); }
需要注意的是,必须要将结构体实例声明为可变的,才能修改其中的字段,Rust 不支持将某个结构体某个字段标记为可变。
简化结构体创建
下面的函数类似一个构建函数,返回了 User 结构体的实例:
#![allow(unused)] fn main() { fn build_user(email: String, username: String) -> User { User { email: email, username: username, active: true, sign_in_count: 1, } } }
它接收两个字符串参数: email 和 username,然后使用它们来创建一个 User 结构体,并且返回。可以注意到这两行: email: email 和 username: username,非常的扎眼,因为实在有些啰嗦,如果你从 TypeScript 过来,肯定会鄙视 Rust 一番,不过好在,它也不是无可救药:
#![allow(unused)] fn main() { fn build_user(email: String, username: String) -> User { User { email, username, active: true, sign_in_count: 1, } } }
如上所示,当函数参数和结构体字段同名时,可以直接使用缩略的方式进行初始化,跟 TypeScript 中一模一样。
结构体更新语法
在实际场景中,有一种情况很常见:根据已有的结构体实例,创建新的结构体实例,例如根据已有的 user1 实例来构建 user2:
#![allow(unused)] fn main() { let user2 = User { active: user1.active, username: user1.username, email: String::from("another@example.com"), sign_in_count: user1.sign_in_count, }; }
好在 Rust 为我们提供了 结构体更新语法:
#![allow(unused)] fn main() { let user2 = User { email: String::from("another@example.com"), ..user1 }; }
因为 user2 仅仅在 email 上与 user1 不同,因此我们只需要对 email 进行赋值,剩下的通过结构体更新语法 ..user1 即可完成。
.. 语法表明凡是我们没有显式声明的字段,全部从 user1 中自动获取。需要注意的是 ..user1 必须在结构体的尾部使用。
结构体更新语法跟赋值语句
=非常相像,因此在上面代码中,user1的部分字段所有权被转移到user2中:username字段发生了所有权转移,作为结果,user1无法再被使用。聪明的读者肯定要发问了:明明有三个字段进行了自动赋值,为何只有
username发生了所有权转移?仔细回想一下所有权那一节的内容,我们提到了
Copy特征:实现了Copy特征的类型无需所有权转移,可以直接在赋值时进行 数据拷贝,其中bool和u64类型就实现了Copy特征,因此active和sign_in_count字段在赋值给user2时,仅仅发生了拷贝,而不是所有权转移。值得注意的是:
username所有权被转移给了user2,导致了user1无法再被使用,但是并不代表user1内部的其它字段不能被继续使用,例如:
#[derive(Debug)] struct User { active: bool, username: String, email: String, sign_in_count: u64, } fn main() { let user1 = User { email: String::from("someone@example.com"), username: String::from("someusername123"), active: true, sign_in_count: 1, }; let user2 = User { active: user1.active, username: user1.username, email: String::from("another@example.com"), sign_in_count: user1.sign_in_count, }; println!("{}", user1.active); // 下面这行会报错 println!("{:?}", user1); }
结构体的内存排列
先来看以下代码:
#[derive(Debug)] struct File { name: String, data: Vec<u8>, } fn main() { let f1 = File { name: String::from("f1.txt"), data: Vec::new(), }; let f1_name = &f1.name; let f1_length = &f1.data.len(); println!("{:?}", f1); println!("{} is {} bytes long", f1_name, f1_length); }
上面定义的 File 结构体在内存中的排列如下图所示:

从图中可以清晰的看出 File 结构体两个字段 name 和 data 分别拥有底层两个 [u8] 数组的所有权(String 类型的底层也是 [u8] 数组),通过 ptr 指针指向底层数组的内存地址,这里你可以把 ptr 指针理解为 Rust 中的引用类型。
该图片也侧面印证了:把结构体中具有所有权的字段转移出去后,将无法再访问该字段,但是可以正常访问其它的字段。
元组结构体(Tuple Struct)
结构体必须要有名称,但是结构体的字段可以没有名称,这种结构体长得很像元组,因此被称为元组结构体,例如:
#![allow(unused)] fn main() { struct Color(i32, i32, i32); struct Point(i32, i32, i32); let black = Color(0, 0, 0); let origin = Point(0, 0, 0); }
元组结构体在你希望有一个整体名称,但是又不关心里面字段的名称时将非常有用。例如上面的 Point 元组结构体,众所周知 3D 点是 (x, y, z) 形式的坐标点,因此我们无需再为内部的字段逐一命名为:x, y, z。
单元结构体(Unit-like Struct)
如果你定义一个类型,但是不关心该类型的内容, 只关心它的行为时,就可以使用 单元结构体:
#![allow(unused)] fn main() { struct AlwaysEqual; let subject = AlwaysEqual; // 我们不关心 AlwaysEqual 的字段数据,只关心它的行为,因此将它声明为单元结构体,然后再为它实现某个特征 impl SomeTrait for AlwaysEqual { } }
结构体数据的所有权
在之前的 User 结构体的定义中,有一处细节:我们使用了自身拥有所有权的 String 类型而不是基于引用的 &str 字符串切片类型。这是一个有意而为之的选择:因为我们想要这个结构体拥有它所有的数据,而不是从其它地方借用数据。
你也可以让 User 结构体从其它对象借用数据,不过这么做,就需要引入生命周期(lifetimes)这个新概念(也是一个复杂的概念),简而言之,生命周期能确保结构体的作用范围要比它所借用的数据的作用范围要小。
总之,如果你想在结构体中使用一个引用,就必须加上生命周期,否则就会报错:
struct User { username: &str, email: &str, sign_in_count: u64, active: bool, } fn main() { let user1 = User { email: "someone@example.com", username: "someusername123", active: true, sign_in_count: 1, }; }
编译器会抱怨它需要生命周期标识符:
error[E0106]: missing lifetime specifier
--> src/main.rs:2:15
|
2 | username: &str,
| ^ expected named lifetime parameter // 需要一个生命周期
|
help: consider introducing a named lifetime parameter // 考虑像下面的代码这样引入一个生命周期
|
1 ~ struct User<'a> {
2 ~ username: &'a str,
|
error[E0106]: missing lifetime specifier
--> src/main.rs:3:12
|
3 | email: &str,
| ^ expected named lifetime parameter
|
help: consider introducing a named lifetime parameter
|
1 ~ struct User<'a> {
2 | username: &str,
3 ~ email: &'a str,
|
未来在生命周期中会讲到如何修复这个问题以便在结构体中存储引用,不过在那之前,我们会避免在结构体中使用引用类型。
使用 #[derive(Debug)] 来打印结构体的信息
在前面的代码中我们使用 #[derive(Debug)] 对结构体进行了标记,这样才能使用 println!("{:?}", s); 的方式对其进行打印输出,如果不加,看看会发生什么:
struct Rectangle { width: u32, height: u32, } fn main() { let rect1 = Rectangle { width: 30, height: 50, }; println!("rect1 is {}", rect1); }
首先可以观察到,上面使用了 {} 而不是之前的 {:?},运行后报错:
error[E0277]: `Rectangle` doesn't implement `std::fmt::Display`
提示我们结构体 Rectangle 没有实现 Display 特征,这是因为如果我们使用 {} 来格式化输出,那对应的类型就必须实现 Display 特征,以前学习的基本类型,都默认实现了该特征:
fn main() { let v = 1; let b = true; println!("{}, {}", v, b); }
上面代码不会报错,那么结构体为什么不默认实现 Display 特征呢?原因在于结构体较为复杂,例如考虑以下问题:你想要逗号对字段进行分割吗?需要括号吗?加在什么地方?所有的字段都应该显示?类似的还有很多,由于这种复杂性,Rust 不希望猜测我们想要的是什么,而是把选择权交给我们自己来实现:如果要用 {} 的方式打印结构体,那就自己实现 Display 特征。
接下来继续阅读报错:
= help: the trait `std::fmt::Display` is not implemented for `Rectangle`
= note: in format strings you may be able to use `{:?}` (or {:#?} for pretty-print) instead
上面提示我们使用 {:?} 来试试,这个方式我们在本文的前面也见过,下面来试试:
#![allow(unused)] fn main() { println!("rect1 is {:?}", rect1); }
可是依然无情报错了:
error[E0277]: `Rectangle` doesn't implement `Debug`
好在,聪明的编译器又一次给出了提示:
= help: the trait `Debug` is not implemented for `Rectangle`
= note: add `#[derive(Debug)]` to `Rectangle` or manually `impl Debug for Rectangle`
让我们实现 Debug 特征,Oh No,就是不想实现 Display 特征,才用的 {:?},怎么又要实现 Debug,但是仔细看,提示中有一行: add #[derive(Debug)] to Rectangle, 哦?这不就是我们前文一直在使用的吗?
首先,Rust 默认不会为我们实现 Debug,为了实现,有两种方式可以选择:
- 手动实现
- 使用
derive派生实现
后者简单的多,但是也有限制,这里我们就不再深入讲解,来看看该如何使用:
#[derive(Debug)] struct Rectangle { width: u32, height: u32, } fn main() { let rect1 = Rectangle { width: 30, height: 50, }; println!("rect1 is {:?}", rect1); }
此时运行程序,就不再有错误,输出如下:
$ cargo run
rect1 is Rectangle { width: 30, height: 50 }
这个输出格式看上去也不赖嘛,虽然未必是最好的。这种格式是 Rust 自动为我们提供的实现,看上基本就跟结构体的定义形式一样。
当结构体较大时,我们可能希望能够有更好的输出表现,此时可以使用 {:#?} 来替代 {:?},输出如下:
rect1 is Rectangle {
width: 30,
height: 50,
}
此时结构体的输出跟我们创建时候的代码几乎一模一样了!当然,如果大家还是不满足,那最好还是自己实现 Display 特征,以向用户更美的展示你的私藏结构体。。
还有一个简单的输出 debug 信息的方法,那就是使用 dbg! 宏,它会拿走表达式的所有权,然后打印出相应的文件名、行号等 debug 信息,当然还有我们需要的表达式的求值结果。除此之外,它最终还会把表达式值的所有权返回!
dbg!输出到标准错误输出stderr,而println!输出到标准输出stdout
下面的例子中清晰的展示了 dbg! 如何在打印出信息的同时,还把表达式的值赋给了 width:
#[derive(Debug)] struct Rectangle { width: u32, height: u32, } fn main() { let scale = 2; let rect1 = Rectangle { width: dbg!(30 * scale), height: 50, }; dbg!(&rect1); }
最终的 debug 输出如下:
$ cargo run
[src/main.rs:10] 30 * scale = 60
[src/main.rs:14] &rect1 = Rectangle {
width: 60,
height: 50,
}
可以看到,我们想要的 debug 信息几乎都有了:代码所在的文件名、行号、表达式以及表达式的值,简直完美!
枚举
枚举(enum 或 enumeration)允许你通过列举可能的成员来定义一个枚举类型,例如扑克牌花色:
#![allow(unused)] fn main() { enum PokerSuit { Clubs, Spades, Diamonds, Hearts, } }
扑克总共有四种花色,而这里我们枚举出所有的可能值,这也正是 枚举 名称的由来。
任何一张扑克,它的花色肯定会落在四种花色中,而且也只会落在其中一个花色上,这种特性非常适合枚举的使用,因为枚举值只可能是其中某一个成员。抽象来看,四种花色尽管是不同的花色,但是它们都是扑克花色这个概念,因此当某个函数处理扑克花色时,可以把它们当作相同的类型进行传参。
细心的读者应该注意到,我们对之前的 枚举类型 和 枚举值 进行了重点标注,这是容易被混淆的概念,总而言之:
枚举类型是一个类型,它会包含所有可能的枚举成员, 而枚举值是该类型中的具体某个成员的实例。
枚举值
现在来创建 PokerSuit 枚举类型的两个成员实例:
#![allow(unused)] fn main() { let heart = PokerSuit::Hearts; let diamond = PokerSuit::Diamonds; }
我们通过 :: 操作符来访问 PokerSuit 下的具体成员,从代码可以清晰看出,heart 和 diamond 都是 PokerSuit 枚举类型的,接着可以定义一个函数来使用它们:
fn main() { let heart = PokerSuit::Hearts; let diamond = PokerSuit::Diamonds; print_suit(heart); print_suit(diamond); } fn print_suit(card: PokerSuit) { println!("{:?}",card); }
print_suit 函数的参数类型是 PokerSuit,因此我们可以把 heart 和 diamond 传给它,虽然 heart 是基于 PokerSuit 下的 Hearts 成员实例化的,但是它是货真价实的 PokerSuit 枚举类型。
接下来,我们想让扑克牌变得更加实用,那么需要给每张牌赋予一个值:A(1)-K(13),这样再加上花色,就是一张真实的扑克牌了,例如红心 A。
目前来说,枚举值还不能带有值,因此先用结构体来实现:
enum PokerSuit { Clubs, Spades, Diamonds, Hearts, } struct PokerCard { suit: PokerSuit, value: u8 } fn main() { let c1 = PokerCard { suit: PokerSuit::Clubs, value: 1, }; let c2 = PokerCard { suit: PokerSuit::Diamonds, value: 12, }; }
这段代码很好的完成了它的使命,通过结构体 PokerCard 来代表一张牌,结构体的 suit 字段表示牌的花色,类型是 PokerSuit 枚举类型,value 字段代表扑克牌的数值。
可以吗?可以!好吗?说实话,不咋地,因为还有简洁得多的方式来实现:
enum PokerCard { Clubs(u8), Spades(u8), Diamonds(u8), Hearts(u8), } fn main() { let c1 = PokerCard::Spades(5); let c2 = PokerCard::Diamonds(13); }
直接将数据信息关联到枚举成员上,省去近一半的代码,这种实现是不是更优雅?
不仅如此,同一个枚举类型下的不同成员还能持有不同的数据类型,例如让某些花色打印 1-13 的字样,另外的花色打印上 A-K 的字样:
enum PokerCard { Clubs(u8), Spades(u8), Diamonds(char), Hearts(char), } fn main() { let c1 = PokerCard::Spades(5); let c2 = PokerCard::Diamonds('A'); }
回想一下,遇到这种不同类型的情况,再用我们之前的结构体实现方式,可行吗?也许可行,但是会复杂很多。
再来看一个来自标准库中的例子:
#![allow(unused)] fn main() { struct Ipv4Addr { // --snip-- } struct Ipv6Addr { // --snip-- } enum IpAddr { V4(Ipv4Addr), V6(Ipv6Addr), } }
这个例子跟我们之前的扑克牌很像,只不过枚举成员包含的类型更复杂了,变成了结构体:分别通过 Ipv4Addr 和 Ipv6Addr 来定义两种不同的 IP 数据。
从这些例子可以看出,任何类型的数据都可以放入枚举成员中: 例如字符串、数值、结构体甚至另一个枚举。
增加一些挑战?先看以下代码:
enum Message { Quit, Move { x: i32, y: i32 }, Write(String), ChangeColor(i32, i32, i32), } fn main() { let m1 = Message::Quit; let m2 = Message::Move{x:1,y:1}; let m3 = Message::ChangeColor(255,255,0); }
该枚举类型代表一条消息,它包含四个不同的成员:
Quit没有任何关联数据Move包含一个匿名结构体Write包含一个String字符串ChangeColor包含三个i32
当然,我们也可以用结构体的方式来定义这些消息:
#![allow(unused)] fn main() { struct QuitMessage; // 单元结构体 struct MoveMessage { x: i32, y: i32, } struct WriteMessage(String); // 元组结构体 struct ChangeColorMessage(i32, i32, i32); // 元组结构体 }
由于每个结构体都有自己的类型,因此我们无法在需要同一类型的地方进行使用,例如某个函数它的功能是接受消息并进行发送,那么用枚举的方式,就可以接收不同的消息,但是用结构体,该函数无法接受 4 个不同的结构体作为参数。
而且从代码规范角度来看,枚举的实现更简洁,代码内聚性更强,不像结构体的实现,分散在各个地方。
同一化类型
最后,再用一个实际项目中的简化片段,来结束枚举类型的语法学习。
例如我们有一个 WEB 服务,需要接受用户的长连接,假设连接有两种:TcpStream 和 TlsStream,但是我们希望对这两个连接的处理流程相同,也就是用同一个函数来处理这两个连接,代码如下:
#![allow(unused)] fn main() { fn new (stream: TcpStream) { let mut s = stream; if tls { s = negotiate_tls(stream) } // websocket是一个WebSocket<TcpStream>或者 // WebSocket<native_tls::TlsStream<TcpStream>>类型 websocket = WebSocket::from_raw_socket( stream, ......) } }
此时,枚举类型就能帮上大忙:
#![allow(unused)] fn main() { enum Websocket { Tcp(Websocket<TcpStream>), Tls(Websocket<native_tls::TlsStream<TcpStream>>), } }
Option 枚举用于处理空值
在其它编程语言中,往往都有一个 null 关键字,该关键字用于表明一个变量当前的值为空,也就是不存在值。当你对这些 null 进行操作时,例如调用一个方法,就会直接抛出null 异常,导致程序的崩溃,因此我们在编程时需要格外的小心去处理这些 null 空值。
Tony Hoare,
null的发明者,曾经说过一段非常有名的话我称之为我十亿美元的错误。当时,我在使用一个面向对象语言设计第一个综合性的面向引用的类型系统。我的目标是通过编译器的自动检查来保证所有引用的使用都应该是绝对安全的。不过在设计过程中,我未能抵抗住诱惑,引入了空引用的概念,因为它非常容易实现。就是因为这个决策,引发了无数错误、漏洞和系统崩溃,在之后的四十多年中造成了数十亿美元的苦痛和伤害。
尽管如此,空值的表达依然非常有意义,因为空值表示当前时刻变量的值是缺失的。有鉴于此,Rust 吸取了众多教训,决定抛弃 null,而改为使用 Option 枚举变量来表述这种结果。
Option 枚举包含两个成员,一个成员表示含有值:Some(T), 另一个表示没有值:None,定义如下:
#![allow(unused)] fn main() { enum Option<T> { Some(T), None, } }
其中 T 是泛型参数,Some(T)表示该枚举成员的数据类型是 T,换句话说,Some 可以包含任何类型的数据。
Option<T> 枚举是如此有用以至于它被包含在了 prelude(prelude 属于 Rust 标准库,Rust 会将最常用的类型、函数等提前引入其中,省得我们再手动引入)之中,你不需要将其显式引入作用域。另外,它的成员 Some 和 None 也是如此,无需使用 Option:: 前缀就可直接使用 Some 和 None。
再来看以下代码:
#![allow(unused)] fn main() { let some_number = Some(5); let some_string = Some("a string"); let absent_number: Option<i32> = None; }
如果使用 None 而不是 Some,需要告诉 Rust Option<T> 是什么类型的,因为编译器只通过 None 值无法推断出 Some 成员保存的值的类型。
当有一个 Some 值时,我们就知道存在一个值,而这个值保存在 Some 中。当有个 None 值时,在某种意义上,它跟空值具有相同的意义:并没有一个有效的值。那么,Option<T> 为什么就比空值要好呢?
简而言之,因为 Option<T> 和 T(这里 T 可以是任何类型)是不同的类型,例如,这段代码不能编译,因为它尝试将 Option<i8>(Option<T>) 与 i8(T) 相加:
#![allow(unused)] fn main() { let x: i8 = 5; let y: Option<i8> = Some(5); let sum = x + y; }
如果运行这些代码,将得到类似这样的错误信息:
error[E0277]: the trait bound `i8: std::ops::Add<std::option::Option<i8>>` is
not satisfied
-->
|
5 | let sum = x + y;
| ^ no implementation for `i8 + std::option::Option<i8>`
|
很好!事实上,错误信息意味着 Rust 不知道该如何将 Option<i8> 与 i8 相加,因为它们的类型不同。当在 Rust 中拥有一个像 i8 这样类型的值时,编译器确保它总是有一个有效的值,我们可以放心使用而无需做空值检查。只有当使用 Option<i8>(或者任何用到的类型)的时候才需要担心可能没有值,而编译器会确保我们在使用值之前处理了为空的情况。
换句话说,在对 Option<T> 进行 T 的运算之前必须将其转换为 T。通常这能帮助我们捕获到空值最常见的问题之一:期望某值不为空但实际上为空的情况。
不再担心会错误的使用一个空值,会让你对代码更加有信心。为了拥有一个可能为空的值,你必须要显式的将其放入对应类型的 Option<T> 中。接着,当使用这个值时,必须明确的处理值为空的情况。只要一个值不是 Option<T> 类型,你就 可以 安全的认定它的值不为空。这是 Rust 的一个经过深思熟虑的设计决策,来限制空值的泛滥以增加 Rust 代码的安全性。
那么当有一个 Option<T> 的值时,如何从 Some 成员中取出 T 的值来使用它呢?Option<T> 枚举拥有大量用于各种情况的方法:你可以查看它的文档。熟悉 Option<T> 的方法将对你的 Rust 之旅非常有用。
总的来说,为了使用 Option<T> 值,需要编写处理每个成员的代码。你想要一些代码只当拥有 Some(T) 值时运行,允许这些代码使用其中的 T。也希望一些代码在值为 None 时运行,这些代码并没有一个可用的 T 值。match 表达式就是这么一个处理枚举的控制流结构:它会根据枚举的成员运行不同的代码,这些代码可以使用匹配到的值中的数据。
这里先简单看一下 match 的大致模样,在模式匹配中,我们会详细讲解:
#![allow(unused)] fn main() { fn plus_one(x: Option<i32>) -> Option<i32> { match x { None => None, Some(i) => Some(i + 1), } } let five = Some(5); let six = plus_one(five); let none = plus_one(None); }
plus_one 通过 match 来处理不同 Option 的情况。
数组
在 Rust 中,最常用的数组有两种,第一种是速度很快但是长度固定的 array,第二种是可动态增长的但是有性能损耗的 Vector,方便起见,我们称 array 为数组,Vector 为动态数组。
这两个数组的关系跟 &str 与 String 的关系很像,前者是长度固定的字符串切片,后者是可动态增长的字符串。其实,在 Rust 中无论是 String 还是 Vector,它们都是 Rust 的高级类型:集合类型,在后面章节会有详细介绍。
对于本章节,我们的重点还是放在数组 array 上。数组的具体定义很简单:将多个类型相同的元素依次组合在一起,就是一个数组。结合上面的内容,可以得出数组的三要素:
- 长度固定
- 元素必须有相同的类型
- 依次线性排列
这里再啰嗦一句,我们这里说的数组是 Rust 的基本类型,是固定长度的,这点与其他编程语言不同,其它编程语言的数组往往是可变长度的,与 Rust 中的动态数组 Vector 类似。
创建数组
在 Rust 中,数组是这样定义的:
fn main() { let a = [1, 2, 3, 4, 5]; }
数组语法跟 JavaScript 很像,也跟大多数编程语言很像。由于它的元素类型大小固定,且长度也是固定,因此数组 array 是存储在栈上,性能也会非常优秀。与此对应,动态数组 Vector 是存储在堆上,因此长度可以动态改变。当你不确定是使用数组还是动态数组时,那就应该使用后者。
举个例子,在需要知道一年中各个月份名称的程序中,你很可能希望使用的是数组而不是动态数组。因为月份是固定的,它总是只包含 12 个元素:
#![allow(unused)] fn main() { let months = ["January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December"]; }
在一些时候,还需要为数组声明类型,如下所示:
#![allow(unused)] fn main() { let a: [i32; 5] = [1, 2, 3, 4, 5]; }
这里,数组类型是通过方括号语法声明,i32 是元素类型,分号后面的数字 5 是数组长度,数组类型也从侧面说明了数组的元素类型要统一,长度要固定。
还可以使用下面的语法初始化一个某个值重复出现 N 次的数组:
#![allow(unused)] fn main() { let a = [3; 5]; }
a 数组包含 5 个元素,这些元素的初始化值为 3,聪明的读者已经发现,这种语法跟数组类型的声明语法其实是保持一致的:[3; 5] 和 [类型; 长度]。
访问数组元素
因为数组是连续存放元素的,因此可以通过索引的方式来访问存放其中的元素:
fn main() { let a = [9, 8, 7, 6, 5]; let first = a[0]; // 获取a数组第一个元素 let second = a[1]; // 获取第二个元素 }
与许多语言类似,数组的索引下标是从 0 开始的。此处,first 获取到的值是 9,second 是 8。
越界访问
如果使用超出数组范围的索引访问数组元素,会怎么样?下面是一个接收用户的控制台输入,然后将其作为索引访问数组元素的例子:
use std::io; fn main() { let a = [1, 2, 3, 4, 5]; println!("Please enter an array index."); let mut index = String::new(); // 读取控制台的输出 io::stdin() .read_line(&mut index) .expect("Failed to read line"); let index: usize = index .trim() .parse() .expect("Index entered was not a number"); let element = a[index]; println!( "The value of the element at index {} is: {}", index, element ); }
使用 cargo run 来运行代码,因为数组只有 5 个元素,如果我们试图输入 5 去访问第 6 个元素,则会访问到不存在的数组元素,最终程序会崩溃退出:
Please enter an array index.
5
thread 'main' panicked at 'index out of bounds: the len is 5 but the index is 5', src/main.rs:19:19
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
这就是数组访问越界,访问了数组中不存在的元素,导致 Rust 运行时错误。程序因此退出并显示错误消息,未执行最后的 println! 语句。
当你尝试使用索引访问元素时,Rust 将检查你指定的索引是否小于数组长度。如果索引大于或等于数组长度,Rust 会出现 panic。这种检查只能在运行时进行,比如在上面这种情况下,编译器无法在编译期知道用户运行代码时将输入什么值。
这种就是 Rust 的安全特性之一。在很多系统编程语言中,并不会检查数组越界问题,你会访问到无效的内存地址获取到一个风马牛不相及的值,最终导致在程序逻辑上出现大问题,而且这种问题会非常难以检查。
数组元素为非基础类型
学习了上面的知识,很多朋友肯定觉得已经学会了Rust的数组类型,但现实会给我们一记重锤,实际开发中还会碰到一种情况,就是数组元素是非基本类型的,这时候大家一定会这样写。
#![allow(unused)] fn main() { let array = [String::from("rust is good!"); 8]; println!("{:#?}", array); }
然后你会惊喜的得到编译错误。
error[E0277]: the trait bound `String: std::marker::Copy` is not satisfied
--> src/main.rs:7:18
|
7 | let array = [String::from("rust is good!"); 8];
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `std::marker::Copy` is not implemented for `String`
|
= note: the `Copy` trait is required because this value will be copied for each element of the array
有些还没有看过特征的小伙伴,有可能不太明白这个报错,不过这个目前可以不提,我们就拿之前所学的所有权知识,就可以思考明白,前面几个例子都是Rust的基本类型,而基本类型在Rust中赋值是以Copy的形式,let array=[3;5]底层就是不断的Copy出来的,但很可惜复杂类型都没有深拷贝,只能一个个创建。
正确的写法,应该调用std::array::from_fn
#![allow(unused)] fn main() { let array: [String; 8] = core::array::from_fn(|i| String::from("rust is good!")); println!("{:#?}", array); }
数组切片
在之前的章节,我们有讲到 切片 这个概念,它允许你引用集合中的部分连续片段,而不是整个集合,对于数组也是,数组切片允许我们引用数组的一部分:
#![allow(unused)] fn main() { let a: [i32; 5] = [1, 2, 3, 4, 5]; let slice: &[i32] = &a[1..3]; assert_eq!(slice, &[2, 3]); }
上面的数组切片 slice 的类型是&[i32],与之对比,数组的类型是[i32;5],简单总结下切片的特点:
- 切片的长度可以与数组不同,并不是固定的,而是取决于你使用时指定的起始和结束位置
- 创建切片的代价非常小,因为切片只是针对底层数组的一个引用
- 切片类型[T]拥有不固定的大小,而切片引用类型&[T]则具有固定的大小,因为 Rust 很多时候都需要固定大小数据类型,因此&[T]更有用,
&str字符串切片也同理
总结
最后,让我们以一个综合性使用数组的例子来结束本章:
fn main() { // 编译器自动推导出one的类型 let one = [1, 2, 3]; // 显式类型标注 let two: [u8; 3] = [1, 2, 3]; let blank1 = [0; 3]; let blank2: [u8; 3] = [0; 3]; // arrays是一个二维数组,其中每一个元素都是一个数组,元素类型是[u8; 3] let arrays: [[u8; 3]; 4] = [one, two, blank1, blank2]; // 借用arrays的元素用作循环中 for a in &arrays { print!("{:?}: ", a); // 将a变成一个迭代器,用于循环 // 你也可以直接用for n in a {}来进行循环 for n in a.iter() { print!("\t{} + 10 = {}", n, n+10); } let mut sum = 0; // 0..a.len,是一个 Rust 的语法糖,其实就等于一个数组,元素是从0,1,2一直增加到到a.len-1 for i in 0..a.len() { sum += a[i]; } println!("\t({:?} = {})", a, sum); } }
做个总结,数组虽然很简单,但是其实还是存在几个要注意的点:
- 数组类型容易跟数组切片混淆,[T;n]描述了一个数组的类型,而[T]描述了切片的类型, 因为切片是运行期的数据结构,它的长度无法在编译期得知,因此不能用[T;n]的形式去描述
[u8; 3]和[u8; 4]是不同的类型,数组的长度也是类型的一部分- 在实际开发中,使用最多的是数组切片[T],我们往往通过引用的方式去使用
&[T],因为后者有固定的类型大小
流程控制
Rust 程序是从上而下顺序执行的,在此过程中,我们可以通过循环、分支等流程控制方式,更好的实现相应的功能。
使用 if 来做分支控制
if else 无处不在 -- 鲁迅
只要你拥有其它语言的编程经验,就一定会有以下认知:if else 表达式根据条件执行不同的代码分支:
#![allow(unused)] fn main() { if condition == true { // A... } else { // B... } }
该代码读作:若 condition 的值为 true,则执行 A 代码,否则执行 B 代码。
先看下面代码:
fn main() { let condition = true; let number = if condition { 5 } else { 6 }; println!("The value of number is: {}", number); }
以上代码有以下几点要注意:
if语句块是表达式,这里我们使用if表达式的返回值来给number进行赋值:number的值是5- 用
if来赋值时,要保证每个分支返回的类型一样(事实上,这种说法不完全准确,见这里),此处返回的5和6就是同一个类型,如果返回类型不一致就会报错
error[E0308]: if and else have incompatible types
--> src/main.rs:4:18
|
4 | let number = if condition {
| __________________^
5 | | 5
6 | | } else {
7 | | "six"
8 | | };
| |_____^ expected integer, found &str // 期望整数类型,但却发现&str字符串切片
|
= note: expected type `{integer}`
found type `&str`
使用 else if 来处理多重条件
可以将 else if 与 if、else 组合在一起实现更复杂的条件分支判断:
fn main() { let n = 6; if n % 4 == 0 { println!("number is divisible by 4"); } else if n % 3 == 0 { println!("number is divisible by 3"); } else if n % 2 == 0 { println!("number is divisible by 2"); } else { println!("number is not divisible by 4, 3, or 2"); } }
程序执行时,会按照自上至下的顺序执行每一个分支判断,一旦成功,则跳出 if 语句块,最终本程序会匹配执行 else if n % 3 == 0 的分支,输出 "number is divisible by 3"。
有一点要注意,就算有多个分支能匹配,也只有第一个匹配的分支会被执行!
如果代码中有大量的 else if 会让代码变得极其丑陋,不过不用担心,下一章的 match 专门用以解决多分支模式匹配的问题。
循环控制
在 Rust 语言中有三种循环方式:for、while 和 loop,其中 for 循环是 Rust 循环王冠上的明珠。
for 循环
for 循环是 Rust 的大杀器:
fn main() { for i in 1..=5 { println!("{}", i); } }
以上代码循环输出一个从 1 到 5 的序列,简单粗暴,核心就在于 for 和 in 的联动,语义表达如下:
#![allow(unused)] fn main() { for 元素 in 集合 { // 使用元素干一些你懂我不懂的事情 } }
注意,使用 for 时我们往往使用集合的引用形式,除非你不想在后面的代码中继续使用该集合(比如我们这里使用了 container 的引用)。如果不使用引用的话,所有权会被转移(move)到 for 语句块中,后面就无法再使用这个集合了):
#![allow(unused)] fn main() { for item in &container { // ... } }
对于实现了
copy特征的数组(例如 [i32; 10] )而言,for item in arr并不会把arr的所有权转移,而是直接对其进行了拷贝,因此循环之后仍然可以使用arr。
如果想在循环中,修改该元素,可以使用 mut 关键字:
#![allow(unused)] fn main() { for item in &mut collection { // ... } }
总结如下:
| 使用方法 | 等价使用方式 | 所有权 |
|---|---|---|
for item in collection | for item in IntoIterator::into_iter(collection) | 转移所有权 |
for item in &collection | for item in collection.iter() | 不可变借用 |
for item in &mut collection | for item in collection.iter_mut() | 可变借用 |
如果想在循环中获取元素的索引:
fn main() { let a = [4, 3, 2, 1]; // `.iter()` 方法把 `a` 数组变成一个迭代器 for (i, v) in a.iter().enumerate() { println!("第{}个元素是{}", i + 1, v); } }
有同学可能会想到,如果我们想用 for 循环控制某个过程执行 10 次,但是又不想单独声明一个变量来控制这个流程,该怎么写?
#![allow(unused)] fn main() { for _ in 0..10 { // ... } }
可以用 _ 来替代 i 用于 for 循环中,在 Rust 中 _ 的含义是忽略该值或者类型的意思,如果不使用 _,那么编译器会给你一个 变量未使用的 的警告。
两种循环方式优劣对比
以下代码,使用了两种循环方式:
#![allow(unused)] fn main() { // 第一种 let collection = [1, 2, 3, 4, 5]; for i in 0..collection.len() { let item = collection[i]; // ... } // 第二种 for item in collection { } }
第一种方式是循环索引,然后通过索引下标去访问集合,第二种方式是直接循环集合中的元素,优劣如下:
- 性能:第一种使用方式中
collection[index]的索引访问,会因为边界检查(Bounds Checking)导致运行时的性能损耗 —— Rust 会检查并确认index是否落在集合内,但是第二种直接迭代的方式就不会触发这种检查,因为编译器会在编译时就完成分析并证明这种访问是合法的 - 安全:第一种方式里对
collection的索引访问是非连续的,存在一定可能性在两次访问之间,collection发生了变化,导致脏数据产生。而第二种直接迭代的方式是连续访问,因此不存在这种风险(这里是因为所有权吗?是的话可能要强调一下)
由于 for 循环无需任何条件限制,也不需要通过索引来访问,因此是最安全也是最常用的,通过与下面的 while 的对比,我们能看到为什么 for 会更加安全。
continue
使用 continue 可以跳过当前当次的循环,开始下次的循环:
#![allow(unused)] fn main() { for i in 1..4 { if i == 2 { continue; } println!("{}", i); } }
上面代码对 1 到 3 的序列进行迭代,且跳过值为 2 时的循环,输出如下:
1
3
break
使用 break 可以直接跳出当前整个循环:
#![allow(unused)] fn main() { for i in 1..4 { if i == 2 { break; } println!("{}", i); } }
上面代码对 1 到 3 的序列进行迭代,在遇到值为 2 时的跳出整个循环,后面的循环不再执行,输出如下:
1
while 循环
如果你需要一个条件来循环,当该条件为 true 时,继续循环,条件为 false,跳出循环,那么 while 就非常适用:
fn main() { let mut n = 0; while n <= 5 { println!("{}!", n); n = n + 1; } println!("我出来了!"); }
该 while 循环,只有当 n 小于等于 5 时,才执行,否则就立刻跳出循环,因此在上述代码中,它会先从 0 开始,满足条件,进行循环,然后是 1,满足条件,进行循环,最终到 6 的时候,大于 5,不满足条件,跳出 while 循环,执行 我出来了 的打印,然后程序结束:
0!
1!
2!
3!
4!
5!
我出来了!
当然,你也可以用其它方式组合实现,例如 loop(无条件循环,将在下面介绍) + if + break:
fn main() { let mut n = 0; loop { if n > 5 { break } println!("{}", n); n+=1; } println!("我出来了!"); }
可以看出,在这种循环场景下,while 要简洁的多。
while vs for
我们也能用 while 来实现 for 的功能:
fn main() { let a = [10, 20, 30, 40, 50]; let mut index = 0; while index < 5 { println!("the value is: {}", a[index]); index = index + 1; } }
这里,代码对数组中的元素进行计数。它从索引 0 开始,并接着循环直到遇到数组的最后一个索引(这时,index < 5 不再为真)。运行这段代码会打印出数组中的每一个元素:
the value is: 10
the value is: 20
the value is: 30
the value is: 40
the value is: 50
数组中的所有五个元素都如期被打印出来。尽管 index 在某一时刻会到达值 5,不过循环在其尝试从数组获取第六个值(会越界)之前就停止了。
但这个过程很容易出错;如果索引长度不正确会导致程序 panic。这也使程序更慢,因为编译器增加了运行时代码来对每次循环的每个元素进行条件检查。
for循环代码如下:
fn main() { let a = [10, 20, 30, 40, 50]; for element in a.iter() { println!("the value is: {}", element); } }
可以看出,for 并不会使用索引去访问数组,因此更安全也更简洁,同时避免 运行时的边界检查,性能更高。
loop 循环
对于循环而言,loop 循环毋庸置疑,是适用面最高的,它可以适用于所有循环场景(虽然能用,但是在很多场景下, for 和 while 才是最优选择),因为 loop 就是一个简单的无限循环,你可以在内部实现逻辑通过 break 关键字来控制循环何时结束。
当使用 loop 时,必不可少的伙伴是 break 关键字,它能让循环在满足某个条件时跳出:
fn main() { let mut counter = 0; let result = loop { counter += 1; if counter == 10 { break counter * 2; } }; println!("The result is {}", result); }
以上代码当 counter 递增到 10 时,就会通过 break 返回一个 counter * 2 的值,最后赋给 result 并打印出来。
这里有几点值得注意:
- break 可以单独使用,也可以带一个返回值,有些类似
return - loop 是一个表达式,因此可以返回一个值
模式匹配
match 和 if let
在 Rust 中,模式匹配最常用的就是 match 和 if let,本章节将对两者及相关的概念进行详尽介绍。
先来看一个关于 match 的简单例子:
enum Direction { East, West, North, South, } fn main() { let dire = Direction::South; match dire { Direction::East => println!("East"), Direction::North | Direction::South => { println!("South or North"); }, _ => println!("West"), }; }
这里我们想去匹配 dire 对应的枚举类型,因此在 match 中用三个匹配分支来完全覆盖枚举变量 Direction 的所有成员类型,有以下几点值得注意:
match的匹配必须要穷举出所有可能,因此这里用_来代表未列出的所有可能性match的每一个分支都必须是一个表达式,且所有分支的表达式最终返回值的类型必须相同- X | Y,类似逻辑运算符
或,代表该分支可以匹配X也可以匹配Y,只要满足一个即可
其实 match 跟其他语言中的 switch 非常像,_ 类似于 switch 中的 default。
match 匹配
首先来看看 match 的通用形式:
#![allow(unused)] fn main() { match target { 模式1 => 表达式1, 模式2 => { 语句1; 语句2; 表达式2 }, _ => 表达式3 } }
该形式清晰的说明了何为模式,何为模式匹配:将模式与 target 进行匹配,即为模式匹配,而模式匹配不仅仅局限于 match,后面我们会详细阐述。
match 允许我们将一个值与一系列的模式相比较,并根据相匹配的模式执行对应的代码,下面让我们来一一详解,先看一个例子:
#![allow(unused)] fn main() { enum Coin { Penny, Nickel, Dime, Quarter, } fn value_in_cents(coin: Coin) -> u8 { match coin { Coin::Penny => { println!("Lucky penny!"); 1 }, Coin::Nickel => 5, Coin::Dime => 10, Coin::Quarter => 25, } } }
value_in_cents 函数根据匹配到的硬币,返回对应的美分数值。match 后紧跟着的是一个表达式,跟 if 很像,但是 if 后的表达式必须是一个布尔值,而 match 后的表达式返回值可以是任意类型,只要能跟后面的分支中的模式匹配起来即可,这里的 coin 是枚举 Coin 类型。
接下来是 match 的分支。一个分支有两个部分:一个模式和针对该模式的处理代码。第一个分支的模式是 Coin::Penny,其后的 => 运算符将模式和将要运行的代码分开。这里的代码就仅仅是表达式 1,不同分支之间使用逗号分隔。
当 match 表达式执行时,它将目标值 coin 按顺序依次与每一个分支的模式相比较,如果模式匹配了这个值,那么模式之后的代码将被执行。如果模式并不匹配这个值,将继续执行下一个分支。
每个分支相关联的代码是一个表达式,而表达式的结果值将作为整个 match 表达式的返回值。如果分支有多行代码,那么需要用 {} 包裹,同时最后一行代码需要是一个表达式。
使用 match 表达式赋值
还有一点很重要,match 本身也是一个表达式,因此可以用它来赋值:
enum IpAddr { Ipv4, Ipv6 } fn main() { let ip1 = IpAddr::Ipv6; let ip_str = match ip1 { IpAddr::Ipv4 => "127.0.0.1", _ => "::1", }; println!("{}", ip_str); }
因为这里匹配到 _ 分支,所以将 "::1" 赋值给了 ip_str。
模式绑定
模式匹配的另外一个重要功能是从模式中取出绑定的值,例如:
#![allow(unused)] fn main() { #[derive(Debug)] enum UsState { Alabama, Alaska, // --snip-- } enum Coin { Penny, Nickel, Dime, Quarter(UsState), // 25美分硬币 } }
其中 Coin::Quarter 成员还存放了一个值:美国的某个州(因为在 1999 年到 2008 年间,美国在 25 美分(Quarter)硬币的背后为 50 个州印刷了不同的标记,其它硬币都没有这样的设计)。
接下来,我们希望在模式匹配中,获取到 25 美分硬币上刻印的州的名称:
#![allow(unused)] fn main() { fn value_in_cents(coin: Coin) -> u8 { match coin { Coin::Penny => 1, Coin::Nickel => 5, Coin::Dime => 10, Coin::Quarter(state) => { println!("State quarter from {:?}!", state); 25 }, } } }
上面代码中,在匹配 Coin::Quarter(state) 模式时,我们把它内部存储的值绑定到了 state 变量上,因此 state 变量就是对应的 UsState 枚举类型。
例如有一个印了阿拉斯加州标记的 25 分硬币:Coin::Quarter(UsState::Alaska), 它在匹配时,state 变量将被绑定 UsState::Alaska 的枚举值。
再来看一个更复杂的例子:
enum Action { Say(String), MoveTo(i32, i32), ChangeColorRGB(u16, u16, u16), } fn main() { let actions = [ Action::Say("Hello Rust".to_string()), Action::MoveTo(1,2), Action::ChangeColorRGB(255,255,0), ]; for action in actions { match action { Action::Say(s) => { println!("{}", s); }, Action::MoveTo(x, y) => { println!("point from (0, 0) move to ({}, {})", x, y); }, Action::ChangeColorRGB(r, g, _) => { println!("change color into '(r:{}, g:{}, b:0)', 'b' has been ignored", r, g, ); } } } }
运行后输出:
$ cargo run
Compiling world_hello v0.1.0 (/Users/sunfei/development/rust/world_hello)
Finished dev [unoptimized + debuginfo] target(s) in 0.16s
Running `target/debug/world_hello`
Hello Rust
point from (0, 0) move to (1, 2)
change color into '(r:255, g:255, b:0)', 'b' has been ignored
穷尽匹配
在文章的开头,我们简单总结过 match 的匹配必须穷尽所有情况,下面来举例说明,例如:
enum Direction { East, West, North, South, } fn main() { let dire = Direction::South; match dire { Direction::East => println!("East"), Direction::North | Direction::South => { println!("South or North"); }, }; }
我们没有处理 Direction::West 的情况,因此会报错:
error[E0004]: non-exhaustive patterns: `West` not covered // 非穷尽匹配,`West` 没有被覆盖
--> src/main.rs:10:11
|
1 | / enum Direction {
2 | | East,
3 | | West,
| | ---- not covered
4 | | North,
5 | | South,
6 | | }
| |_- `Direction` defined here
...
10 | match dire {
| ^^^^ pattern `West` not covered // 模式 `West` 没有被覆盖
|
= help: ensure that all possible cases are being handled, possibly by adding wildcards or more match arms
= note: the matched value is of type `Direction`
不禁想感叹,Rust 的编译器真强大,忍不住想爆粗口了,sorry,如果你以后进一步深入使用 Rust 也会像我这样感叹的。Rust 编译器清晰地知道 match 中有哪些分支没有被覆盖, 这种行为能强制我们处理所有的可能性,有效避免传说中价值十亿美金的 null 陷阱。
_ 通配符
当我们不想在匹配时列出所有值的时候,可以使用 Rust 提供的一个特殊模式,例如,u8 可以拥有 0 到 255 的有效的值,但是我们只关心 1、3、5 和 7 这几个值,不想列出其它的 0、2、4、6、8、9 一直到 255 的值。那么, 我们不必一个一个列出所有值, 因为可以使用特殊的模式 _ 替代:
#![allow(unused)] fn main() { let some_u8_value = 0u8; match some_u8_value { 1 => println!("one"), 3 => println!("three"), 5 => println!("five"), 7 => println!("seven"), _ => (), } }
通过将 _ 其放置于其他分支后,_ 将会匹配所有遗漏的值。() 表示返回单元类型与所有分支返回值的类型相同,所以当匹配到 _ 后,什么也不会发生。
然而,在某些场景下,我们其实只关心某一个值是否存在,此时 match 就显得过于啰嗦。
if let 匹配
有时会遇到只有一个模式的值需要被处理,其它值直接忽略的场景,如果用 match 来处理就要写成下面这样:
#![allow(unused)] fn main() { let v = Some(3u8); match v { Some(3) => println!("three"), _ => (), } }
我们只想要对 Some(3) 模式进行匹配, 不想处理任何其他 Some<u8> 值或 None 值。但是为了满足 match 表达式(穷尽性)的要求,写代码时必须在处理完这唯一的成员后加上 _ => (),这样会增加不少无用的代码。
俗话说“杀鸡焉用牛刀”,我们完全可以用 if let 的方式来实现:
#![allow(unused)] fn main() { if let Some(3) = v { println!("three"); } }
这两种匹配对于新手来说,可能有些难以抉择,但是只要记住一点就好:当你只要匹配一个条件,且忽略其他条件时就用 if let ,否则都用 match。
matches!宏
Rust 标准库中提供了一个非常实用的宏:matches!,它可以将一个表达式跟模式进行匹配,然后返回匹配的结果 true or false。
例如,有一个动态数组,里面存有以下枚举:
enum MyEnum { Foo, Bar } fn main() { let v = vec![MyEnum::Foo,MyEnum::Bar,MyEnum::Foo]; }
现在如果想对 v 进行过滤,只保留类型是 MyEnum::Foo 的元素,你可能想这么写:
#![allow(unused)] fn main() { v.iter().filter(|x| x == MyEnum::Foo); }
但是,实际上这行代码会报错,因为你无法将 x 直接跟一个枚举成员进行比较。好在,你可以使用 match 来完成,但是会导致代码更为啰嗦,是否有更简洁的方式?答案是使用 matches!:
#![allow(unused)] fn main() { v.iter().filter(|x| matches!(x, MyEnum::Foo)); }
很简单也很简洁,再来看看更多的例子:
#![allow(unused)] fn main() { let foo = 'f'; assert!(matches!(foo, 'A'..='Z' | 'a'..='z')); let bar = Some(4); assert!(matches!(bar, Some(x) if x > 2)); }
变量覆盖
无论是 match 还是 if let,他们都可以在模式匹配时覆盖掉老的值,绑定新的值:
fn main() { let age = Some(30); println!("在匹配前,age是{:?}",age); if let Some(age) = age { println!("匹配出来的age是{}",age); } println!("在匹配后,age是{:?}",age); }
cargo run 运行后输出如下:
在匹配前,age是Some(30)
匹配出来的age是30
在匹配后,age是Some(30)
可以看出在 if let 中,= 右边 Some(i32) 类型的 age 被左边 i32 类型的新 age 覆盖了,该覆盖一直持续到 if let 语句块的结束。因此第三个 println! 输出的 age 依然是 Some(i32) 类型。
对于 match 类型也是如此:
fn main() { let age = Some(30); println!("在匹配前,age是{:?}",age); match age { Some(age) => println!("匹配出来的age是{}",age), _ => () } println!("在匹配后,age是{:?}",age); }
需要注意的是,match 中的变量覆盖其实不是那么的容易看出,因此要小心!
课后练习
Rust By Practice,支持代码在线编辑和运行,并提供详细的习题解答。
解构 Option
在枚举那章,提到过 Option 枚举,它用来解决 Rust 中变量是否有值的问题,定义如下:
#![allow(unused)] fn main() { enum Option<T> { Some(T), None, } }
简单解释就是:一个变量要么有值:Some(T), 要么为空:None。
那么现在的问题就是该如何去使用这个 Option 枚举类型,根据我们上一节的经验,可以通过 match 来实现。
因为
Option,Some,None都包含在prelude中,因此你可以直接通过名称来使用它们,而无需以Option::Some这种形式去使用,总之,千万不要因为调用路径变短了,就忘记Some和None也是Option底下的枚举成员!
匹配 Option<T>
使用 Option<T>,是为了从 Some 中取出其内部的 T 值以及处理没有值的情况,为了演示这一点,下面一起来编写一个函数,它获取一个 Option<i32>,如果其中含有一个值,将其加一;如果其中没有值,则函数返回 None 值:
#![allow(unused)] fn main() { fn plus_one(x: Option<i32>) -> Option<i32> { match x { None => None, Some(i) => Some(i + 1), } } let five = Some(5); let six = plus_one(five); let none = plus_one(None); }
plus_one 接受一个 Option<i32> 类型的参数,同时返回一个 Option<i32> 类型的值(这种形式的函数在标准库内随处所见),在该函数的内部处理中,如果传入的是一个 None ,则返回一个 None 且不做任何处理;如果传入的是一个 Some(i32),则通过模式绑定,把其中的值绑定到变量 i 上,然后返回 i+1 的值,同时用 Some 进行包裹。
为了进一步说明,假设 plus_one 函数接受的参数值 x 是 Some(5),来看看具体的分支匹配情况:
传入参数 Some(5)
None => None,
首先是匹配 None 分支,因为值 Some(5) 并不匹配模式 None,所以继续匹配下一个分支。
Some(i) => Some(i + 1),
Some(5) 与 Some(i) 匹配吗?当然匹配!它们是相同的成员。i 绑定了 Some 中包含的值,因此 i 的值是 5。接着匹配分支的代码被执行,最后将 i 的值加一并返回一个含有值 6 的新 Some。
传入参数 None
接着考虑下 plus_one 的第二个调用,这次传入的 x 是 None, 我们进入 match 并与第一个分支相比较。
None => None,
匹配上了!接着程序继续执行该分支后的代码:返回表达式 None 的值,也就是返回一个 None,因为第一个分支就匹配到了,其他的分支将不再比较。
模式适用场景
模式
模式是 Rust 中的特殊语法,它用来匹配类型中的结构和数据,它往往和 match 表达式联用,以实现强大的模式匹配能力。模式一般由以下内容组合而成:
- 字面值
- 解构的数组、枚举、结构体或者元组
- 变量
- 通配符
- 占位符
所有可能用到模式的地方
match 分支
#![allow(unused)] fn main() { match VALUE { PATTERN => EXPRESSION, PATTERN => EXPRESSION, PATTERN => EXPRESSION, } }
如上所示,match 的每个分支就是一个模式,因为 match 匹配是穷尽式的,因此我们往往需要一个特殊的模式 _,来匹配剩余的所有情况:
#![allow(unused)] fn main() { match VALUE { PATTERN => EXPRESSION, PATTERN => EXPRESSION, _ => EXPRESSION, } }
if let 分支
if let 往往用于匹配一个模式,而忽略剩下的所有模式的场景:
#![allow(unused)] fn main() { if let PATTERN = SOME_VALUE { } }
while let 条件循环
一个与 if let 类似的结构是 while let 条件循环,它允许只要模式匹配就一直进行 while 循环。下面展示了一个使用 while let 的例子:
#![allow(unused)] fn main() { // Vec是动态数组 let mut stack = Vec::new(); // 向数组尾部插入元素 stack.push(1); stack.push(2); stack.push(3); // stack.pop从数组尾部弹出元素 while let Some(top) = stack.pop() { println!("{}", top); } }
这个例子会打印出 3、2 接着是 1。pop 方法取出动态数组的最后一个元素并返回 Some(value),如果动态数组是空的,将返回 None,对于 while 来说,只要 pop 返回 Some 就会一直不停的循环。一旦其返回 None,while 循环停止。我们可以使用 while let 来弹出栈中的每一个元素。
你也可以用 loop + if let 或者 match 来实现这个功能,但是会更加啰嗦。
for 循环
#![allow(unused)] fn main() { let v = vec!['a', 'b', 'c']; for (index, value) in v.iter().enumerate() { println!("{} is at index {}", value, index); } }
这里使用 enumerate 方法产生一个迭代器,该迭代器每次迭代会返回一个 (索引,值) 形式的元组,然后用 (index,value) 来匹配。
let 语句
#![allow(unused)] fn main() { let PATTERN = EXPRESSION; }
是的, 该语句我们已经用了无数次了,它也是一种模式匹配:
#![allow(unused)] fn main() { let x = 5; }
这其中,x 也是一种模式绑定,代表将匹配的值绑定到变量 x 上。因此,在 Rust 中,变量名也是一种模式,只不过它比较朴素很不起眼罢了。
#![allow(unused)] fn main() { let (x, y, z) = (1, 2, 3); }
上面将一个元组与模式进行匹配(模式和值的类型必需相同!),然后把 1, 2, 3 分别绑定到 x, y, z 上。
模式匹配要求两边的类型必须相同,否则就会导致下面的报错:
#![allow(unused)] fn main() { let (x, y) = (1, 2, 3); }
#![allow(unused)] fn main() { error[E0308]: mismatched types --> src/main.rs:4:5 | 4 | let (x, y) = (1, 2, 3); | ^^^^^^ --------- this expression has type `({integer}, {integer}, {integer})` | | | expected a tuple with 3 elements, found one with 2 elements | = note: expected tuple `({integer}, {integer}, {integer})` found tuple `(_, _)` For more information about this error, try `rustc --explain E0308`. error: could not compile `playground` due to previous error }
对于元组来说,元素个数也是类型的一部分!
函数参数
函数参数也是模式:
#![allow(unused)] fn main() { fn foo(x: i32) { // 代码 } }
其中 x 就是一个模式,你还可以在参数中匹配元组:
fn print_coordinates(&(x, y): &(i32, i32)) { println!("Current location: ({}, {})", x, y); } fn main() { let point = (3, 5); print_coordinates(&point); }
&(3, 5) 会匹配模式 &(x, y),因此 x 得到了 3,y 得到了 5。
let 和 if let
对于以下代码,编译器会报错:
#![allow(unused)] fn main() { let Some(x) = some_option_value; }
因为右边的值可能不为 Some,而是 None,这种时候就不能进行匹配,也就是上面的代码遗漏了 None 的匹配。
类似 let , for和match 都必须要求完全覆盖匹配,才能通过编译( 不可驳模式匹配 )。
但是对于 if let,就可以这样使用:
#![allow(unused)] fn main() { if let Some(x) = some_option_value { println!("{}", x); } }
因为 if let 允许匹配一种模式,而忽略其余的模式( 可驳模式匹配 )。
全模式列表
我们已领略过许多不同类型模式的例子,本节的目标就是把这些模式语法都罗列出来,方便大家检索查阅(模式匹配在我们的开发中会经常用到)。
匹配字面值
#![allow(unused)] fn main() { let x = 1; match x { 1 => println!("one"), 2 => println!("two"), 3 => println!("three"), _ => println!("anything"), } }
这段代码会打印 one 因为 x 的值是 1,如果希望代码获得特定的具体值,那么这种语法很有用。
匹配命名变量
在 match 中,我们有讲过变量覆盖的问题,这个在匹配命名变量时会遇到:
fn main() { let x = Some(5); let y = 10; match x { Some(50) => println!("Got 50"), Some(y) => println!("Matched, y = {:?}", y), _ => println!("Default case, x = {:?}", x), } println!("at the end: x = {:?}, y = {:?}", x, y); }
让我们看看当 match 语句运行的时候发生了什么。第一个匹配分支的模式并不匹配 x 中定义的值,所以代码继续执行。
第二个匹配分支中的模式引入了一个新变量 y,它会匹配任何 Some 中的值。因为这里的 y 在 match 表达式的作用域中,而不是之前 main 作用域中,所以这是一个新变量,不是开头声明为值 10 的那个 y。这个新的 y 绑定会匹配任何 Some 中的值,在这里是 x 中的值。因此这个 y 绑定了 x 中 Some 内部的值。这个值是 5,所以这个分支的表达式将会执行并打印出 Matched,y = 5。
如果 x 的值是 None 而不是 Some(5),头两个分支的模式不会匹配,所以会匹配模式 _。这个分支的模式中没有引入变量 x,所以此时表达式中的 x 会是外部没有被覆盖的 x,也就是 None。
一旦 match 表达式执行完毕,其作用域也就结束了,同理内部 y 的作用域也结束了。最后的 println! 会打印 at the end: x = Some(5), y = 10。
如果你不想引入变量覆盖,那么需要使用匹配守卫(match guard)的方式,稍后在匹配守卫提供的额外条件中会讲解。
单分支多模式
在 match 表达式中,可以使用 | 语法匹配多个模式,它代表 或的意思。例如,如下代码将 x 的值与匹配分支相比较,第一个分支有 或 选项,意味着如果 x 的值匹配此分支的任何一个模式,它就会运行:
#![allow(unused)] fn main() { let x = 1; match x { 1 | 2 => println!("one or two"), 3 => println!("three"), _ => println!("anything"), } }
上面的代码会打印 one or two。
通过序列 ..= 匹配值的范围
在数值类型中我们有讲到一个序列语法,该语法不仅可以用于循环中,还能用于匹配模式。
..= 语法允许你匹配一个闭区间序列内的值。在如下代码中,当模式匹配任何在此序列内的值时,该分支会执行:
#![allow(unused)] fn main() { let x = 5; match x { 1..=5 => println!("one through five"), _ => println!("something else"), } }
如果 x 是 1、2、3、4 或 5,第一个分支就会匹配。这相比使用 | 运算符表达相同的意思更为方便;相比 1..=5,使用 | 则不得不指定 1 | 2 | 3 | 4 | 5 这五个值,而使用 ..= 指定序列就简短的多,比如希望匹配比如从 1 到 1000 的数字的时候!
序列只允许用于数字或字符类型,原因是:它们可以连续,同时编译器在编译期可以检查该序列是否为空,字符和数字值是 Rust 中仅有的可以用于判断是否为空的类型。
如下是一个使用字符类型序列的例子:
#![allow(unused)] fn main() { let x = 'c'; match x { 'a'..='j' => println!("early ASCII letter"), 'k'..='z' => println!("late ASCII letter"), _ => println!("something else"), } }
Rust 知道 'c' 位于第一个模式的序列内,所以会打印出 early ASCII letter。
解构并分解值
也可以使用模式来解构结构体、枚举、元组、数组和引用。
解构结构体
下面代码展示了如何用 let 解构一个带有两个字段 x 和 y 的结构体 Point:
struct Point { x: i32, y: i32, } fn main() { let p = Point { x: 0, y: 7 }; let Point { x: a, y: b } = p; assert_eq!(0, a); assert_eq!(7, b); }
这段代码创建了变量 a 和 b 来匹配结构体 p 中的 x 和 y 字段,这个例子展示了模式中的变量名不必与结构体中的字段名一致。不过通常希望变量名与字段名一致以便于理解变量来自于哪些字段。
因为变量名匹配字段名是常见的,同时因为 let Point { x: x, y: y } = p; 中 x 和 y 重复了,所以对于匹配结构体字段的模式存在简写:只需列出结构体字段的名称,则模式创建的变量会有相同的名称。下例与上例有着相同行为的代码,不过 let 模式创建的变量为 x 和 y 而不是 a 和 b:
struct Point { x: i32, y: i32, } fn main() { let p = Point { x: 0, y: 7 }; let Point { x, y } = p; assert_eq!(0, x); assert_eq!(7, y); }
这段代码创建了变量 x 和 y,与结构体 p 中的 x 和 y 字段相匹配。其结果是变量 x 和 y 包含结构体 p 中的值。
也可以使用字面值作为结构体模式的一部分进行解构,而不是为所有的字段创建变量。这允许我们测试一些字段为特定值的同时创建其他字段的变量。
下文展示了固定某个字段的匹配方式:
struct Point { x: i32, y: i32, } fn main() { let p = Point { x: 0, y: 7 }; match p { Point { x, y: 0 } => println!("On the x axis at {}", x), Point { x: 0, y } => println!("On the y axis at {}", y), Point { x, y } => println!("On neither axis: ({}, {})", x, y), } }
首先是 match 第一个分支,指定匹配 y 为 0 的 Point;
然后第二个分支在第一个分支之后,匹配 y 不为 0,x 为 0 的 Point;
最后一个分支匹配 x 不为 0,y 也不为 0 的 Point。
在这个例子中,值 p 因为其 x 包含 0 而匹配第二个分支,因此会打印出 On the y axis at 7。
解构枚举
下面代码以 Message 枚举为例,编写一个 match 使用模式解构每一个内部值:
enum Message { Quit, Move { x: i32, y: i32 }, Write(String), ChangeColor(i32, i32, i32), } fn main() { let msg = Message::ChangeColor(0, 160, 255); match msg { Message::Quit => { println!("The Quit variant has no data to destructure.") } Message::Move { x, y } => { println!( "Move in the x direction {} and in the y direction {}", x, y ); } Message::Write(text) => println!("Text message: {}", text), Message::ChangeColor(r, g, b) => { println!( "Change the color to red {}, green {}, and blue {}", r, g, b ) } } }
这里老生常谈一句话,模式匹配一样要类型相同,因此匹配 Message::Move{1,2} 这样的枚举值,就必须要用 Message::Move{x,y} 这样的同类型模式才行。
这段代码会打印出 Change the color to red 0, green 160, and blue 255。尝试改变 msg 的值来观察其他分支代码的运行。
对于像 Message::Quit 这样没有任何数据的枚举成员,不能进一步解构其值。只能匹配其字面值 Message::Quit,因此模式中没有任何变量。
对于另外两个枚举成员,就用相同类型的模式去匹配出对应的值即可。
解构嵌套的结构体和枚举
目前为止,所有的例子都只匹配了深度为一级的结构体或枚举。 match 也可以匹配嵌套的项!
例如使用下面的代码来同时支持 RGB 和 HSV 色彩模式:
enum Color { Rgb(i32, i32, i32), Hsv(i32, i32, i32), } enum Message { Quit, Move { x: i32, y: i32 }, Write(String), ChangeColor(Color), } fn main() { let msg = Message::ChangeColor(Color::Hsv(0, 160, 255)); match msg { Message::ChangeColor(Color::Rgb(r, g, b)) => { println!( "Change the color to red {}, green {}, and blue {}", r, g, b ) } Message::ChangeColor(Color::Hsv(h, s, v)) => { println!( "Change the color to hue {}, saturation {}, and value {}", h, s, v ) } _ => () } }
match 第一个分支的模式匹配一个 Message::ChangeColor 枚举成员,该枚举成员又包含了一个 Color::Rgb 的枚举成员,最终绑定了 3 个内部的 i32 值。第二个,就交给亲爱的读者来思考完成。
解构结构体和元组
我们甚至可以用复杂的方式来混合、匹配和嵌套解构模式。如下是一个复杂结构体的例子,其中结构体和元组嵌套在元组中,并将所有的原始类型解构出来:
#![allow(unused)] fn main() { struct Point { x: i32, y: i32, } let ((feet, inches), Point {x, y}) = ((3, 10), Point { x: 3, y: -10 }); }
这种将复杂类型分解匹配的方式,可以让我们单独得到感兴趣的某个值。
解构数组
对于数组,我们可以用类似元组的方式解构,分为两种情况:
定长数组
#![allow(unused)] fn main() { let arr: [u16; 2] = [114, 514]; let [x, y] = arr; assert_eq!(x, 114); assert_eq!(y, 514); }
不定长数组
#![allow(unused)] fn main() { let arr: &[u16] = &[114, 514]; if let [x, ..] = arr { assert_eq!(x, &114); } if let &[.., y] = arr { assert_eq!(y, 514); } let arr: &[u16] = &[]; assert!(matches!(arr, [..])); assert!(!matches!(arr, [x, ..])); }
忽略模式中的值
有时忽略模式中的一些值是很有用的,比如在 match 中的最后一个分支使用 _ 模式匹配所有剩余的值。 你也可以在另一个模式中使用 _ 模式,使用一个以下划线开始的名称,或者使用 .. 忽略所剩部分的值。
使用 _ 忽略整个值
虽然 _ 模式作为 match 表达式最后的分支特别有用,但是它的作用还不限于此。例如可以将其用于函数参数中:
fn foo(_: i32, y: i32) { println!("This code only uses the y parameter: {}", y); } fn main() { foo(3, 4); }
这段代码会完全忽略作为第一个参数传递的值 3,并会打印出 This code only uses the y parameter: 4。
大部分情况当你不再需要特定函数参数时,最好修改签名不再包含无用的参数。在一些情况下忽略函数参数会变得特别有用,比如实现特征时,当你需要特定类型签名但是函数实现并不需要某个参数时。此时编译器就不会警告说存在未使用的函数参数,就跟使用命名参数一样。
使用嵌套的 _ 忽略部分值
可以在一个模式内部使用 _ 忽略部分值:
#![allow(unused)] fn main() { let mut setting_value = Some(5); let new_setting_value = Some(10); match (setting_value, new_setting_value) { (Some(_), Some(_)) => { println!("Can't overwrite an existing customized value"); } _ => { setting_value = new_setting_value; } } println!("setting is {:?}", setting_value); }
这段代码会打印出 Can't overwrite an existing customized value 接着是 setting is Some(5)。
第一个匹配分支,我们不关心里面的值,只关心元组中两个元素的类型,因此对于 Some 中的值,直接进行忽略。
剩下的形如 (Some(_),None),(None, Some(_)), (None,None) 形式,都由第二个分支 _ 进行分配。
还可以在一个模式中的多处使用下划线来忽略特定值,如下所示,这里忽略了一个五元元组中的第二和第四个值:
#![allow(unused)] fn main() { let numbers = (2, 4, 8, 16, 32); match numbers { (first, _, third, _, fifth) => { println!("Some numbers: {}, {}, {}", first, third, fifth) }, } }
老生常谈:模式匹配一定要类型相同,因此匹配 numbers 元组的模式,也必须有五个值(元组中元素的数量也属于元组类型的一部分)。
这会打印出 Some numbers: 2, 8, 32, 值 4 和 16 会被忽略。
使用下划线开头忽略未使用的变量
如果你创建了一个变量却不在任何地方使用它,Rust 通常会给你一个警告,因为这可能会是个 BUG。但是有时创建一个不会被使用的变量是有用的,比如你正在设计原型或刚刚开始一个项目。这时你希望告诉 Rust 不要警告未使用的变量,为此可以用下划线作为变量名的开头:
fn main() { let _x = 5; let y = 10; }
这里得到了警告说未使用变量 y,至于 x 则没有警告。
注意, 只使用 _ 和使用以下划线开头的名称有些微妙的不同:比如 _x 仍会将值绑定到变量,而 _ 则完全不会绑定。
#![allow(unused)] fn main() { let s = Some(String::from("Hello!")); if let Some(_s) = s { println!("found a string"); } println!("{:?}", s); }
s 是一个拥有所有权的动态字符串,在上面代码中,我们会得到一个错误,因为 s 的值会被转移给 _s,在 println! 中再次使用 s 会报错:
error[E0382]: borrow of partially moved value: `s`
--> src/main.rs:8:22
|
4 | if let Some(_s) = s {
| -- value partially moved here
...
8 | println!("{:?}", s);
| ^ value borrowed here after partial move
只使用下划线本身,则并不会绑定值,因为 s 没有被移动进 _:
#![allow(unused)] fn main() { let s = Some(String::from("Hello!")); if let Some(_) = s { println!("found a string"); } println!("{:?}", s); }
用 .. 忽略剩余值
对于有多个部分的值,可以使用 .. 语法来只使用部分值而忽略其它值,这样也不用再为每一个被忽略的值都单独列出下划线。.. 模式会忽略模式中剩余的任何没有显式匹配的值部分。
#![allow(unused)] fn main() { struct Point { x: i32, y: i32, z: i32, } let origin = Point { x: 0, y: 0, z: 0 }; match origin { Point { x, .. } => println!("x is {}", x), } }
这里列出了 x 值,接着使用了 .. 模式来忽略其它字段,这样的写法要比一一列出其它字段,然后用 _ 忽略简洁的多。
还可以用 .. 来忽略元组中间的某些值:
fn main() { let numbers = (2, 4, 8, 16, 32); match numbers { (first, .., last) => { println!("Some numbers: {}, {}", first, last); }, } }
这里用 first 和 last 来匹配第一个和最后一个值。.. 将匹配并忽略中间的所有值。
然而使用 .. 必须是无歧义的。如果期望匹配和忽略的值是不明确的,Rust 会报错。下面代码展示了一个带有歧义的 .. 例子,因此不能编译:
fn main() { let numbers = (2, 4, 8, 16, 32); match numbers { (.., second, ..) => { println!("Some numbers: {}", second) }, } }
如果编译上面的例子,会得到下面的错误:
error: `..` can only be used once per tuple pattern // 每个元组模式只能使用一个 `..`
--> src/main.rs:5:22
|
5 | (.., second, ..) => {
| -- ^^ can only be used once per tuple pattern
| |
| previously used here // 上一次使用在这里
error: could not compile `world_hello` due to previous error ^^
Rust 无法判断,second 应该匹配 numbers 中的第几个元素,因此这里使用两个 .. 模式,是有很大歧义的!
匹配守卫提供的额外条件
匹配守卫(match guard)是一个位于 match 分支模式之后的额外 if 条件,它能为分支模式提供更进一步的匹配条件。
这个条件可以使用模式中创建的变量:
#![allow(unused)] fn main() { let num = Some(4); match num { Some(x) if x < 5 => println!("less than five: {}", x), Some(x) => println!("{}", x), None => (), } }
这个例子会打印出 less than five: 4。当 num 与模式中第一个分支匹配时,Some(4) 可以与 Some(x) 匹配,接着匹配守卫检查 x 值是否小于 5,因为 4 小于 5,所以第一个分支被选择。
相反如果 num 为 Some(10),因为 10 不小于 5 ,所以第一个分支的匹配守卫为假。接着 Rust 会前往第二个分支,因为这里没有匹配守卫所以会匹配任何 Some 成员。
模式中无法提供类如 if x < 5 的表达能力,我们可以通过匹配守卫的方式来实现。
在之前,我们提到可以使用匹配守卫来解决模式中变量覆盖的问题,那里 match 表达式的模式中新建了一个变量而不是使用 match 之外的同名变量。内部变量覆盖了外部变量,意味着此时不能够使用外部变量的值,下面代码展示了如何使用匹配守卫修复这个问题。
fn main() { let x = Some(5); let y = 10; match x { Some(50) => println!("Got 50"), Some(n) if n == y => println!("Matched, n = {}", n), _ => println!("Default case, x = {:?}", x), } println!("at the end: x = {:?}, y = {}", x, y); }
现在这会打印出 Default case, x = Some(5)。现在第二个匹配分支中的模式不会引入一个覆盖外部 y 的新变量 y,这意味着可以在匹配守卫中使用外部的 y。相比指定会覆盖外部 y 的模式 Some(y),这里指定为 Some(n)。此新建的变量 n 并没有覆盖任何值,因为 match 外部没有变量 n。
匹配守卫 if n == y 并不是一个模式所以没有引入新变量。这个 y 正是 外部的 y 而不是新的覆盖变量 y,这样就可以通过比较 n 和 y 来表达寻找一个与外部 y 相同的值的概念了。
也可以在匹配守卫中使用 或 运算符 | 来指定多个模式,同时匹配守卫的条件会作用于所有的模式。下面代码展示了匹配守卫与 | 的优先级。这个例子中看起来好像 if y 只作用于 6,但实际上匹配守卫 if y 作用于 4、5 和 6 ,在满足 x 属于 4 | 5 | 6 后才会判断 y 是否为 true:
#![allow(unused)] fn main() { let x = 4; let y = false; match x { 4 | 5 | 6 if y => println!("yes"), _ => println!("no"), } }
这个匹配条件表明此分支只匹配 x 值为 4、5 或 6 同时 y 为 true 的情况。
虽然在第一个分支中,x 匹配了模式 4 ,但是对于匹配守卫 if y 来说,因为 y 是 false,因此该守卫条件的值永远是 false,也意味着第一个分支永远无法被匹配。
下面的文字图解释了匹配守卫作用于多个模式时的优先级规则,第一张是正确的:
(4 | 5 | 6) if y => ...
而第二张图是错误的
4 | 5 | (6 if y) => ...
可以通过运行代码时的情况看出这一点:如果匹配守卫只作用于由 | 运算符指定的值列表的最后一个值,这个分支就会匹配且程序会打印出 yes。
@绑定
@(读作 at)运算符允许为一个字段绑定另外一个变量。下面例子中,我们希望测试 Message::Hello 的 id 字段是否位于 3..=7 范围内,同时也希望能将其值绑定到 id_variable 变量中以便此分支中相关的代码可以使用它。我们可以将 id_variable 命名为 id,与字段同名,不过出于示例的目的这里选择了不同的名称。
#![allow(unused)] fn main() { enum Message { Hello { id: i32 }, } let msg = Message::Hello { id: 5 }; match msg { Message::Hello { id: id_variable @ 3..=7 } => { println!("Found an id in range: {}", id_variable) }, Message::Hello { id: 10..=12 } => { println!("Found an id in another range") }, Message::Hello { id } => { println!("Found some other id: {}", id) }, } }
上例会打印出 Found an id in range: 5。通过在 3..=7 之前指定 id_variable @,我们捕获了任何匹配此范围的值并同时将该值绑定到变量 id_variable 上。
第二个分支只在模式中指定了一个范围,id 字段的值可以是 10、11 或 12,不过这个模式的代码并不知情也不能使用 id 字段中的值,因为没有将 id 值保存进一个变量。
最后一个分支指定了一个没有范围的变量,此时确实拥有可以用于分支代码的变量 id,因为这里使用了结构体字段简写语法。不过此分支中没有像头两个分支那样对 id 字段的值进行测试:任何值都会匹配此分支。
当你既想要限定分支范围,又想要使用分支的变量时,就可以用 @ 来绑定到一个新的变量上,实现想要的功能。
@前绑定后解构(Rust 1.56 新增)
使用 @ 还可以在绑定新变量的同时,对目标进行解构:
#[derive(Debug)] struct Point { x: i32, y: i32, } fn main() { // 绑定新变量 `p`,同时对 `Point` 进行解构 let p @ Point {x: px, y: py } = Point {x: 10, y: 23}; println!("x: {}, y: {}", px, py); println!("{:?}", p); let point = Point {x: 10, y: 5}; if let p @ Point {x: 10, y} = point { println!("x is 10 and y is {} in {:?}", y, p); } else { println!("x was not 10 :("); } }
@新特性(Rust 1.53 新增)
考虑下面一段代码:
fn main() { match 1 { num @ 1 | 2 => { println!("{}", num); } _ => {} } }
编译不通过,是因为 num 没有绑定到所有的模式上,只绑定了模式 1,你可能会试图通过这个方式来解决:
#![allow(unused)] fn main() { num @ (1 | 2) }
但是,如果你用的是 Rust 1.53 之前的版本,那这种写法会报错,因为编译器不支持。
至此,模式匹配的内容已经全部完结,复杂但是详尽,想要一次性全部记住属实不易,因此读者可以先留一个印象,等未来需要时,再来翻阅寻找具体的模式实现方式。
方法 Method
从面向对象语言过来的同学对于方法肯定不陌生,class 里面就充斥着方法的概念。在 Rust 中,方法的概念也大差不差,往往和对象成对出现:
#![allow(unused)] fn main() { object.method() }
例如读取一个文件写入缓冲区,如果用函数的写法 read(f, buffer),用方法的写法 f.read(buffer)。不过与其它语言 class 跟方法的联动使用不同(这里可能要修改下),Rust 的方法往往跟结构体、枚举、特征(Trait)一起使用,特征将在后面几章进行介绍。
定义方法
Rust 使用 impl 来定义方法,例如以下代码:
#![allow(unused)] fn main() { struct Circle { x: f64, y: f64, radius: f64, } impl Circle { // new是Circle的关联函数,因为它的第一个参数不是self,且new并不是关键字 // 这种方法往往用于初始化当前结构体的实例 fn new(x: f64, y: f64, radius: f64) -> Circle { Circle { x: x, y: y, radius: radius, } } // Circle的方法,&self表示借用当前的Circle结构体 fn area(&self) -> f64 { std::f64::consts::PI * (self.radius * self.radius) } } }
我们这里先不详细展开讲解,只是先建立对方法定义的大致印象。下面的图片将 Rust 方法定义与其它语言的方法定义做了对比:
可以看出,其它语言中所有定义都在 class 中,但是 Rust 的对象定义和方法定义是分离的,这种数据和使用分离的方式,会给予使用者极高的灵活度。
再来看一个例子:
#[derive(Debug)] struct Rectangle { width: u32, height: u32, } impl Rectangle { fn area(&self) -> u32 { self.width * self.height } } fn main() { let rect1 = Rectangle { width: 30, height: 50 }; println!( "The area of the rectangle is {} square pixels.", rect1.area() ); }
该例子定义了一个 Rectangle 结构体,并且在其上定义了一个 area 方法,用于计算该矩形的面积。
impl Rectangle {} 表示为 Rectangle 实现方法(impl 是实现 implementation 的缩写),这样的写法表明 impl 语句块中的一切都是跟 Rectangle 相关联的。
self、&self 和 &mut self
接下来的内容非常重要,请大家仔细看。在 area 的签名中,我们使用 &self 替代 rectangle: &Rectangle,&self 其实是 self: &Self 的简写(注意大小写)。在一个 impl 块内,Self 指代被实现方法的结构体类型,self 指代此类型的实例,换句话说,self 指代的是 Rectangle 结构体实例,这样的写法会让我们的代码简洁很多,而且非常便于理解:我们为哪个结构体实现方法,那么 self 就是指代哪个结构体的实例。
需要注意的是,self 依然有所有权的概念:
self表示Rectangle的所有权转移到该方法中,这种形式用的较少&self表示该方法对Rectangle的不可变借用&mut self表示可变借用
总之,self 的使用就跟函数参数一样,要严格遵守 Rust 的所有权规则。
回到上面的例子中,选择 &self 的理由跟在函数中使用 &Rectangle 是相同的:我们并不想获取所有权,也无需去改变它,只是希望能够读取结构体中的数据。如果想要在方法中去改变当前的结构体,需要将第一个参数改为 &mut self。仅仅通过使用 self 作为第一个参数来使方法获取实例的所有权是很少见的,这种使用方式往往用于把当前的对象转成另外一个对象时使用,转换完后,就不再关注之前的对象,且可以防止对之前对象的误调用。
简单总结下,使用方法代替函数有以下好处:
- 不用在函数签名中重复书写
self对应的类型 - 代码的组织性和内聚性更强,对于代码维护和阅读来说,好处巨大
方法名跟结构体字段名相同
在 Rust 中,允许方法名跟结构体的字段名相同:
impl Rectangle { fn width(&self) -> bool { self.width > 0 } } fn main() { let rect1 = Rectangle { width: 30, height: 50, }; if rect1.width() { println!("The rectangle has a nonzero width; it is {}", rect1.width); } }
当我们使用 rect1.width() 时,Rust 知道我们调用的是它的方法,如果使用 rect1.width,则是访问它的字段。
一般来说,方法跟字段同名,往往适用于实现 getter 访问器,例如:
pub struct Rectangle { width: u32, height: u32, } impl Rectangle { pub fn new(width: u32, height: u32) -> Self { Rectangle { width, height } } pub fn width(&self) -> u32 { return self.width; } } fn main() { let rect1 = Rectangle::new(30, 50); println!("{}", rect1.width()); }
用这种方式,我们可以把 Rectangle 的字段设置为私有属性,只需把它的 new 和 width 方法设置为公开可见,那么用户就可以创建一个矩形,同时通过访问器 rect1.width() 方法来获取矩形的宽度,因为 width 字段是私有的,当用户访问 rect1.width 字段时,就会报错。注意在此例中,Self 指代的就是被实现方法的结构体 Rectangle。
->运算符到哪去了?在 C/C++ 语言中,有两个不同的运算符来调用方法:
.直接在对象上调用方法,而->在一个对象的指针上调用方法,这时需要先解引用指针。换句话说,如果object是一个指针,那么object->something()和(*object).something()是一样的。Rust 并没有一个与
->等效的运算符;相反,Rust 有一个叫 自动引用和解引用的功能。方法调用是 Rust 中少数几个拥有这种行为的地方。他是这样工作的:当使用
object.something()调用方法时,Rust 会自动为object添加&、&mut或*以便使object与方法签名匹配。也就是说,这些代码是等价的:#![allow(unused)] fn main() { #[derive(Debug,Copy,Clone)] struct Point { x: f64, y: f64, } impl Point { fn distance(&self, other: &Point) -> f64 { let x_squared = f64::powi(other.x - self.x, 2); let y_squared = f64::powi(other.y - self.y, 2); f64::sqrt(x_squared + y_squared) } } let p1 = Point { x: 0.0, y: 0.0 }; let p2 = Point { x: 5.0, y: 6.5 }; p1.distance(&p2); (&p1).distance(&p2); }第一行看起来简洁的多。这种自动引用的行为之所以有效,是因为方法有一个明确的接收者————
self的类型。在给出接收者和方法名的前提下,Rust 可以明确地计算出方法是仅仅读取(&self),做出修改(&mut self)或者是获取所有权(self)。事实上,Rust 对方法接收者的隐式借用让所有权在实践中更友好。
带有多个参数的方法
方法和函数一样,可以使用多个参数:
impl Rectangle { fn area(&self) -> u32 { self.width * self.height } fn can_hold(&self, other: &Rectangle) -> bool { self.width > other.width && self.height > other.height } } fn main() { let rect1 = Rectangle { width: 30, height: 50 }; let rect2 = Rectangle { width: 10, height: 40 }; let rect3 = Rectangle { width: 60, height: 45 }; println!("Can rect1 hold rect2? {}", rect1.can_hold(&rect2)); println!("Can rect1 hold rect3? {}", rect1.can_hold(&rect3)); }
关联函数
现在大家可以思考一个问题,如何为一个结构体定义一个构造器方法?也就是接受几个参数,然后构造并返回该结构体的实例。其实答案在开头的代码片段中就给出了,很简单,参数中不包含 self 即可。
这种定义在 impl 中且没有 self 的函数被称之为关联函数: 因为它没有 self,不能用 f.read() 的形式调用,因此它是一个函数而不是方法,它又在 impl 中,与结构体紧密关联,因此称为关联函数。
在之前的代码中,我们已经多次使用过关联函数,例如 String::from,用于创建一个动态字符串。
#![allow(unused)] fn main() { #[derive(Debug)] struct Rectangle { width: u32, height: u32, } impl Rectangle { fn new(w: u32, h: u32) -> Rectangle { Rectangle { width: w, height: h } } } }
Rust 中有一个约定俗成的规则,使用
new来作为构造器的名称,出于设计上的考虑,Rust 特地没有用new作为关键字
因为是函数,所以不能用 . 的方式来调用,我们需要用 :: 来调用,例如 let sq = Rectangle::new(3, 3);。这个方法位于结构体的命名空间中::: 语法用于关联函数和模块创建的命名空间。
多个 impl 定义
Rust 允许我们为一个结构体定义多个 impl 块,目的是提供更多的灵活性和代码组织性,例如当方法多了后,可以把相关的方法组织在同一个 impl 块中,那么就可以形成多个 impl 块,各自完成一块儿目标:
#![allow(unused)] fn main() { #[derive(Debug)] struct Rectangle { width: u32, height: u32, } impl Rectangle { fn area(&self) -> u32 { self.width * self.height } } impl Rectangle { fn can_hold(&self, other: &Rectangle) -> bool { self.width > other.width && self.height > other.height } } }
当然,就这个例子而言,我们没必要使用两个 impl 块,这里只是为了演示方便。
为枚举实现方法
枚举类型之所以强大,不仅仅在于它好用、可以同一化类型,还在于,我们可以像结构体一样,为枚举实现方法:
#![allow(unused)] enum Message { Quit, Move { x: i32, y: i32 }, Write(String), ChangeColor(i32, i32, i32), } impl Message { fn call(&self) { // 在这里定义方法体 } } fn main() { let m = Message::Write(String::from("hello")); m.call(); }
除了结构体和枚举,我们还能为特征(trait)实现方法,这将在下一章进行讲解,在此之前,先来看看泛型。
泛型与特征
泛型 Generics
Go 语言在 2022 年,就要正式引入泛型,被视为在 1.0 版本后,语言特性发展迈出的一大步,为什么泛型这么重要?到底什么是泛型?Rust 的泛型有几种?
我们在编程中,经常有这样的需求:用同一功能的函数处理不同类型的数据,例如两个数的加法,无论是整数还是浮点数,甚至是自定义类型,都能进行支持。在不支持泛型的编程语言中,通常需要为每一种类型编写一个函数:
fn add_i8(a:i8, b:i8) -> i8 { a + b } fn add_i32(a:i32, b:i32) -> i32 { a + b } fn add_f64(a:f64, b:f64) -> f64 { a + b } fn main() { println!("add i8: {}", add_i8(2i8, 3i8)); println!("add i32: {}", add_i32(20, 30)); println!("add f64: {}", add_f64(1.23, 1.23)); }
上述代码可以正常运行,但是很啰嗦,如果你要支持更多的类型,那么会更繁琐。程序员或多或少都有强迫症,一个好程序员的公认特征就是 —— 懒,这么勤快的写一大堆代码,显然不是咱们的优良传统,是不?
在开始讲解 Rust 的泛型之前,先来看看什么是多态。
在编程的时候,我们经常利用多态。通俗的讲,多态就是好比坦克的炮管,既可以发射普通弹药,也可以发射制导炮弹(导弹),也可以发射贫铀穿甲弹,甚至发射子母弹,没有必要为每一种炮弹都在坦克上分别安装一个专用炮管,即使生产商愿意,炮手也不愿意,累死人啊。所以在编程开发中,我们也需要这样“通用的炮管”,这个“通用的炮管”就是多态。
实际上,泛型就是一种多态。泛型主要目的是为程序员提供编程的便利,减少代码的臃肿,同时可以极大地丰富语言本身的表达能力,为程序员提供了一个合适的炮管。想想,一个函数,可以代替几十个,甚至数百个函数,是一件多么让人兴奋的事情:
fn add<T>(a:T, b:T) -> T { a + b } fn main() { println!("add i8: {}", add(2i8, 3i8)); println!("add i32: {}", add(20, 30)); println!("add f64: {}", add(1.23, 1.23)); }
将之前的代码改成上面这样,就是 Rust 泛型的初印象,这段代码虽然很简洁,但是并不能编译通过,我们会在后面进行详细讲解,现在只要对泛型有个大概的印象即可。
泛型详解
上面代码的 T 就是泛型参数,实际上在 Rust 中,泛型参数的名称你可以任意起,但是出于惯例,我们都用 T ( T 是 type 的首字母)来作为首选,这个名称越短越好,除非需要表达含义,否则一个字母是最完美的。
使用泛型参数,有一个先决条件,必需在使用前对其进行声明:
#![allow(unused)] fn main() { fn largest<T>(list: &[T]) -> T { }
该泛型函数的作用是从列表中找出最大的值,其中列表中的元素类型为 T。首先 largest<T> 对泛型参数 T 进行了声明,然后才在函数参数中进行使用该泛型参数 list: &[T] (还记得 &[T] 类型吧?这是数组切片)。
总之,我们可以这样理解这个函数定义:函数 largest 有泛型类型 T,它有个参数 list,其类型是元素为 T 的数组切片,最后,该函数返回值的类型也是 T。
具体的泛型函数实现如下:
fn largest<T>(list: &[T]) -> T { let mut largest = list[0]; for &item in list.iter() { if item > largest { largest = item; } } largest } fn main() { let number_list = vec![34, 50, 25, 100, 65]; let result = largest(&number_list); println!("The largest number is {}", result); let char_list = vec!['y', 'm', 'a', 'q']; let result = largest(&char_list); println!("The largest char is {}", result); }
运行后报错:
error[E0369]: binary operation `>` cannot be applied to type `T` // `>`操作符不能用于类型`T`
--> src/main.rs:5:17
|
5 | if item > largest {
| ---- ^ ------- T
| |
| T
|
help: consider restricting type parameter `T` // 考虑对T进行类型上的限制 :
|
1 | fn largest<T: std::cmp::PartialOrd>(list: &[T]) -> T {
| ++++++++++++++++++++++
因为 T 可以是任何类型,但不是所有的类型都能进行比较,因此上面的错误中,编译器建议我们给 T 添加一个类型限制:使用 std::cmp::PartialOrd 特征(Trait)对 T 进行限制,特征在下一节会详细介绍,现在你只要理解,该特征的目的就是让类型实现可比较的功能。
还记得我们一开始的 add 泛型函数吗?如果你运行它,会得到以下的报错:
error[E0369]: cannot add `T` to `T` // 无法将 `T` 类型跟 `T` 类型进行相加
--> src/main.rs:2:7
|
2 | a + b
| - ^ - T
| |
| T
|
help: consider restricting type parameter `T`
|
1 | fn add<T: std::ops::Add<Output = T>>(a:T, b:T) -> T {
| +++++++++++++++++++++++++++
同样的,不是所有 T 类型都能进行相加操作,因此我们需要用 std::ops::Add<Output = T> 对 T 进行限制:
#![allow(unused)] fn main() { fn add<T: std::ops::Add<Output = T>>(a:T, b:T) -> T { a + b } }
进行如上修改后,就可以正常运行。
结构体中使用泛型
结构体中的字段类型也可以用泛型来定义,下面代码定义了一个坐标点 Point,它可以存放任何类型的坐标值:
struct Point<T> { x: T, y: T, } fn main() { let integer = Point { x: 5, y: 10 }; let float = Point { x: 1.0, y: 4.0 }; }
这里有两点需要特别的注意:
- 提前声明,跟泛型函数定义类似,首先我们在使用泛型参数之前必需要进行声明
Point<T>,接着就可以在结构体的字段类型中使用T来替代具体的类型 - x 和 y 是相同的类型
第二点非常重要,如果使用不同的类型,那么它会导致下面代码的报错:
struct Point<T> { x: T, y: T, } fn main() { let p = Point{x: 1, y :1.1}; }
错误如下:
error[E0308]: mismatched types //类型不匹配
--> src/main.rs:7:28
|
7 | let p = Point{x: 1, y :1.1};
| ^^^ expected integer, found floating-point number //期望y是整数,但是却是浮点数
当把 1 赋值给 x 时,变量 p 的 T 类型就被确定为整数类型,因此 y 也必须是整数类型,但是我们却给它赋予了浮点数,因此导致报错。
如果想让 x 和 y 既能类型相同,又能类型不同,就需要使用不同的泛型参数:
struct Point<T,U> { x: T, y: U, } fn main() { let p = Point{x: 1, y :1.1}; }
切记,所有的泛型参数都要提前声明:Point<T,U> ! 但是如果你的结构体变成这鬼样:struct Woo<T,U,V,W,X>,那么你需要考虑拆分这个结构体,减少泛型参数的个数和代码复杂度。
枚举中使用泛型
提到枚举类型,Option 永远是第一个应该被想起来的,在之前的章节中,它也多次出现:
#![allow(unused)] fn main() { enum Option<T> { Some(T), None, } }
Option<T> 是一个拥有泛型 T 的枚举类型,它第一个成员是 Some(T),存放了一个类型为 T 的值。得益于泛型的引入,我们可以在任何一个需要返回值的函数中,去使用 Option<T> 枚举类型来做为返回值,用于返回一个任意类型的值 Some(T),或者没有值 None。
对于枚举而言,卧龙凤雏永远是绕不过去的存在:如果是 Option 是卧龙,那么 Result 就一定是凤雏,得两者可得天下:
#![allow(unused)] fn main() { enum Result<T, E> { Ok(T), Err(E), } }
这个枚举和 Option 一样,主要用于函数返回值,与 Option 用于值的存在与否不同,Result 关注的主要是值的正确性。
如果函数正常运行,则最后返回一个 Ok(T),T 是函数具体的返回值类型,如果函数异常运行,则返回一个 Err(E),E 是错误类型。例如打开一个文件:如果成功打开文件,则返回 Ok(std::fs::File),因此 T 对应的是 std::fs::File 类型;而当打开文件时出现问题时,返回 Err(std::io::Error),E 对应的就是 std::io::Error 类型。
方法中使用泛型
上一章中,我们讲到什么是方法以及如何在结构体和枚举上定义方法。方法上也可以使用泛型:
struct Point<T> { x: T, y: T, } impl<T> Point<T> { fn x(&self) -> &T { &self.x } } fn main() { let p = Point { x: 5, y: 10 }; println!("p.x = {}", p.x()); }
使用泛型参数前,依然需要提前声明:impl<T>,只有提前声明了,我们才能在Point<T>中使用它,这样 Rust 就知道 Point 的尖括号中的类型是泛型而不是具体类型。需要注意的是,这里的 Point<T> 不再是泛型声明,而是一个完整的结构体类型,因为我们定义的结构体就是 Point<T> 而不再是 Point。
除了结构体中的泛型参数,我们还能在该结构体的方法中定义额外的泛型参数,就跟泛型函数一样:
struct Point<T, U> { x: T, y: U, } impl<T, U> Point<T, U> { fn mixup<V, W>(self, other: Point<V, W>) -> Point<T, W> { Point { x: self.x, y: other.y, } } } fn main() { let p1 = Point { x: 5, y: 10.4 }; let p2 = Point { x: "Hello", y: 'c'}; let p3 = p1.mixup(p2); println!("p3.x = {}, p3.y = {}", p3.x, p3.y); }
这个例子中,T,U 是定义在结构体 Point 上的泛型参数,V,W 是单独定义在方法 mixup 上的泛型参数,它们并不冲突,说白了,你可以理解为,一个是结构体泛型,一个是函数泛型。
为具体的泛型类型实现方法
对于 Point<T> 类型,你不仅能定义基于 T 的方法,还能针对特定的具体类型,进行方法定义:
#![allow(unused)] fn main() { impl Point<f32> { fn distance_from_origin(&self) -> f32 { (self.x.powi(2) + self.y.powi(2)).sqrt() } } }
这段代码意味着 Point<f32> 类型会有一个方法 distance_from_origin,而其他 T 不是 f32 类型的 Point<T> 实例则没有定义此方法。这个方法计算点实例与坐标(0.0, 0.0) 之间的距离,并使用了只能用于浮点型的数学运算符。
这样我们就能针对特定的泛型类型实现某个特定的方法,对于其它泛型类型则没有定义该方法。
const 泛型(Rust 1.51 版本引入的重要特性)
在之前的泛型中,可以抽象为一句话:针对类型实现的泛型,所有的泛型都是为了抽象不同的类型,那有没有针对值的泛型?可能很多同学感觉很难理解,值怎么使用泛型?不急,我们先从数组讲起。
在数组那节,有提到过很重要的一点:[i32; 2] 和 [i32; 3] 是不同的数组类型,比如下面的代码:
fn display_array(arr: [i32; 3]) { println!("{:?}", arr); } fn main() { let arr: [i32; 3] = [1, 2, 3]; display_array(arr); let arr: [i32;2] = [1,2]; display_array(arr); }
运行后报错:
error[E0308]: mismatched types // 类型不匹配
--> src/main.rs:10:19
|
10 | display_array(arr);
| ^^^ expected an array with a fixed size of 3 elements, found one with 2 elements
// 期望一个长度为3的数组,却发现一个长度为2的
结合代码和报错,可以很清楚的看出,[i32; 3] 和 [i32; 2] 确实是两个完全不同的类型,因此无法用同一个函数调用。
首先,让我们修改代码,让 display_array 能打印任意长度的 i32 数组:
fn display_array(arr: &[i32]) { println!("{:?}", arr); } fn main() { let arr: [i32; 3] = [1, 2, 3]; display_array(&arr); let arr: [i32;2] = [1,2]; display_array(&arr); }
很简单,只要使用数组切片,然后传入 arr 的不可变引用即可。
接着,将 i32 改成所有类型的数组:
fn display_array<T: std::fmt::Debug>(arr: &[T]) { println!("{:?}", arr); } fn main() { let arr: [i32; 3] = [1, 2, 3]; display_array(&arr); let arr: [i32;2] = [1,2]; display_array(&arr); }
也不难,唯一要注意的是需要对 T 加一个限制 std::fmt::Debug,该限制表明 T 可以用在 println!("{:?}", arr) 中,因为 {:?} 形式的格式化输出需要 arr 实现该特征。
通过引用,我们可以很轻松的解决处理任何类型数组的问题,但是如果在某些场景下引用不适宜用或者干脆不能用呢?你们知道为什么以前 Rust 的一些数组库,在使用的时候都限定长度不超过 32 吗?因为它们会为每个长度都单独实现一个函数,简直。。。毫无人性。难道没有什么办法可以解决这个问题吗?
好在,现在咱们有了 const 泛型,也就是针对值的泛型,正好可以用于处理数组长度的问题:
fn display_array<T: std::fmt::Debug, const N: usize>(arr: [T; N]) { println!("{:?}", arr); } fn main() { let arr: [i32; 3] = [1, 2, 3]; display_array(arr); let arr: [i32; 2] = [1, 2]; display_array(arr); }
如上所示,我们定义了一个类型为 [T; N] 的数组,其中 T 是一个基于类型的泛型参数,这个和之前讲的泛型没有区别,而重点在于 N 这个泛型参数,它是一个基于值的泛型参数!因为它用来替代的是数组的长度。
N 就是 const 泛型,定义的语法是 const N: usize,表示 const 泛型 N ,它基于的值类型是 usize。
在泛型参数之前,Rust 完全不适合复杂矩阵的运算,自从有了 const 泛型,一切即将改变。
const 泛型表达式
假设我们某段代码需要在内存很小的平台上工作,因此需要限制函数参数占用的内存大小,此时就可以使用 const 泛型表达式来实现:
// 目前只能在nightly版本下使用 #![allow(incomplete_features)] #![feature(generic_const_exprs)] fn something<T>(val: T) where Assert<{ core::mem::size_of::<T>() < 768 }>: IsTrue, // ^-----------------------------^ 这里是一个 const 表达式,换成其它的 const 表达式也可以 { // } fn main() { something([0u8; 0]); // ok something([0u8; 512]); // ok something([0u8; 1024]); // 编译错误,数组长度是1024字节,超过了768字节的参数长度限制 } // --- pub enum Assert<const CHECK: bool> { // } pub trait IsTrue { // } impl IsTrue for Assert<true> { // }
const fn
@todo
泛型的性能
在 Rust 中泛型是零成本的抽象,意味着你在使用泛型时,完全不用担心性能上的问题。
但是任何选择都是权衡得失的,既然我们获得了性能上的巨大优势,那么又失去了什么呢?Rust 是在编译期为泛型对应的多个类型,生成各自的代码,因此损失了编译速度和增大了最终生成文件的大小。
具体来说:
Rust 通过在编译时进行泛型代码的 单态化(monomorphization)来保证效率。单态化是一个通过填充编译时使用的具体类型,将通用代码转换为特定代码的过程。
编译器所做的工作正好与我们创建泛型函数的步骤相反,编译器寻找所有泛型代码被调用的位置并针对具体类型生成代码。
让我们看看一个使用标准库中 Option 枚举的例子:
#![allow(unused)] fn main() { let integer = Some(5); let float = Some(5.0); }
当 Rust 编译这些代码的时候,它会进行单态化。编译器会读取传递给 Option<T> 的值并发现有两种 Option<T>:一种对应 i32 另一种对应 f64。为此,它会将泛型定义 Option<T> 展开为 Option_i32 和 Option_f64,接着将泛型定义替换为这两个具体的定义。
编译器生成的单态化版本的代码看起来像这样:
enum Option_i32 { Some(i32), None, } enum Option_f64 { Some(f64), None, } fn main() { let integer = Option_i32::Some(5); let float = Option_f64::Some(5.0); }
我们可以使用泛型来编写不重复的代码,而 Rust 将会为每一个实例编译其特定类型的代码。这意味着在使用泛型时没有运行时开销;当代码运行,它的执行效率就跟好像手写每个具体定义的重复代码一样。这个单态化过程正是 Rust 泛型在运行时极其高效的原因。
特征 Trait
如果我们想定义一个文件系统,那么把该系统跟底层存储解耦是很重要的。文件操作主要包含四个:open 、write、read、close,这些操作可以发生在硬盘,可以发生在内存,还可以发生在网络IO甚至(...我实在编不下去了,大家来帮帮我)。总之如果你要为每一种情况都单独实现一套代码,那这种实现将过于繁杂,而且也没那个必要。
要解决上述问题,需要把这些行为抽象出来,就要使用 Rust 中的特征 trait 概念。可能你是第一次听说这个名词,但是不要怕,如果学过其他语言,那么大概率你听说过接口,没错,特征跟接口很类似。
在之前的代码中,我们也多次见过特征的使用,例如 #[derive(Debug)],它在我们定义的类型(struct)上自动派生 Debug 特征,接着可以使用 println!("{:?}", x) 打印这个类型;再例如:
#![allow(unused)] fn main() { fn add<T: std::ops::Add<Output = T>>(a:T, b:T) -> T { a + b } }
通过 std::ops::Add 特征来限制 T,只有 T 实现了 std::ops::Add 才能进行合法的加法操作,毕竟不是所有的类型都能进行相加。
这些都说明一个道理,特征定义了一个可以被共享的行为,只要实现了特征,你就能使用该行为。
定义特征
如果不同的类型具有相同的行为,那么我们就可以定义一个特征,然后为这些类型实现该特征。定义特征是把一些方法组合在一起,目的是定义一个实现某些目标所必需的行为的集合。
例如,我们现在有文章 Post 和微博 Weibo 两种内容载体,而我们想对相应的内容进行总结,也就是无论是文章内容,还是微博内容,都可以在某个时间点进行总结,那么总结这个行为就是共享的,因此可以用特征来定义:
#![allow(unused)] fn main() { pub trait Summary { fn summarize(&self) -> String; } }
这里使用 trait 关键字来声明一个特征,Summary 是特征名。在大括号中定义了该特征的所有方法,在这个例子中是: fn summarize(&self) -> String。
特征只定义行为看起来是什么样的,而不定义行为具体是怎么样的。因此,我们只定义特征方法的签名,而不进行实现,此时方法签名结尾是 ;,而不是一个 {}。
接下来,每一个实现这个特征的类型都需要具体实现该特征的相应方法,编译器也会确保任何实现 Summary 特征的类型都拥有与这个签名的定义完全一致的 summarize 方法。
为类型实现特征
因为特征只定义行为看起来是什么样的,因此我们需要为类型实现具体的特征,定义行为具体是怎么样的。
首先来为 Post 和 Weibo 实现 Summary 特征:
#![allow(unused)] fn main() { pub trait Summary { fn summarize(&self) -> String; } pub struct Post { pub title: String, // 标题 pub author: String, // 作者 pub content: String, // 内容 } impl Summary for Post { fn summarize(&self) -> String { format!("文章{}, 作者是{}", self.title, self.author) } } pub struct Weibo { pub username: String, pub content: String } impl Summary for Weibo { fn summarize(&self) -> String { format!("{}发表了微博{}", self.username, self.content) } } }
实现特征的语法与为结构体、枚举实现方法很像:impl Summary for Post,读作“为 Post 类型实现 Summary 特征”,然后在 impl 的花括号中实现该特征的具体方法。
接下来就可以在这个类型上调用特征的方法:
fn main() { let post = Post{title: "Rust语言简介".to_string(),author: "Sunface".to_string(), content: "Rust棒极了!".to_string()}; let weibo = Weibo{username: "sunface".to_string(),content: "好像微博没Tweet好用".to_string()}; println!("{}",post.summarize()); println!("{}",weibo.summarize()); }
运行输出:
文章 Rust 语言简介, 作者是Sunface
sunface发表了微博好像微博没Tweet好用
说实话,如果特征仅仅如此,你可能会觉得花里胡哨没啥用,接下来就让你见识下 trait 真正的威力。
特征定义与实现的位置(孤儿规则)
上面我们将 Summary 定义成了 pub 公开的。这样,如果他人想要使用我们的 Summary 特征,则可以引入到他们的包中,然后再进行实现。
关于特征实现与定义的位置,有一条非常重要的原则:如果你想要为类型 A 实现特征 T,那么 A 或者 T 至少有一个是在当前作用域中定义的! 例如我们可以为上面的 Post 类型实现标准库中的 Display 特征,这是因为 Post 类型定义在当前的作用域中。同时,我们也可以在当前包中为 String 类型实现 Summary 特征,因为 Summary 定义在当前作用域中。
但是你无法在当前作用域中,为 String 类型实现 Display 特征,因为它们俩都定义在标准库中,其定义所在的位置都不在当前作用域,跟你半毛钱关系都没有,看看就行了。
该规则被称为孤儿规则,可以确保其它人编写的代码不会破坏你的代码,也确保了你不会莫名其妙就破坏了风马牛不相及的代码。
默认实现
你可以在特征中定义具有默认实现的方法,这样其它类型无需再实现该方法,或者也可以选择重载该方法:
#![allow(unused)] fn main() { pub trait Summary { fn summarize(&self) -> String { String::from("(Read more...)") } } }
上面为 Summary 定义了一个默认实现,下面我们编写段代码来测试下:
#![allow(unused)] fn main() { impl Summary for Post {} impl Summary for Weibo { fn summarize(&self) -> String { format!("{}发表了微博{}", self.username, self.content) } } }
可以看到,Post 选择了默认实现,而 Weibo 重载了该方法,调用和输出如下:
#![allow(unused)] fn main() { println!("{}",post.summarize()); println!("{}",weibo.summarize()); }
(Read more...)
sunface发表了微博好像微博没Tweet好用
默认实现允许调用相同特征中的其他方法,哪怕这些方法没有默认实现。如此,特征可以提供很多有用的功能而只需要实现指定的一小部分内容。例如,我们可以定义 Summary 特征,使其具有一个需要实现的 summarize_author 方法,然后定义一个 summarize 方法,此方法的默认实现调用 summarize_author 方法:
#![allow(unused)] fn main() { pub trait Summary { fn summarize_author(&self) -> String; fn summarize(&self) -> String { format!("(Read more from {}...)", self.summarize_author()) } } }
为了使用 Summary,只需要实现 summarize_author 方法即可:
#![allow(unused)] fn main() { impl Summary for Weibo { fn summarize_author(&self) -> String { format!("@{}", self.username) } } println!("1 new weibo: {}", weibo.summarize()); }
weibo.summarize() 会先调用 Summary 特征默认实现的 summarize 方法,通过该方法进而调用 Weibo 为 Summary 实现的 summarize_author 方法,最终输出:1 new weibo: (Read more from @horse_ebooks...)。
使用特征作为函数参数
之前提到过,特征如果仅仅是用来实现方法,那真的有些大材小用,现在我们来讲下,真正可以让特征大放光彩的地方。
现在,先定义一个函数,使用特征作为函数参数:
#![allow(unused)] fn main() { pub fn notify(item: &impl Summary) { println!("Breaking news! {}", item.summarize()); } }
impl Summary,只能说想出这个类型的人真的是起名鬼才,简直太贴切了,故名思义,它的意思是 实现了Summary特征 的 item 参数。
你可以使用任何实现了 Summary 特征的类型作为该函数的参数,同时在函数体内,还可以调用该特征的方法,例如 summarize 方法。具体的说,可以传递 Post 或 Weibo 的实例来作为参数,而其它类如 String 或者 i32 的类型则不能用做该函数的参数,因为它们没有实现 Summary 特征。
特征约束(trait bound)
虽然 impl Trait 这种语法非常好理解,但是实际上它只是一个语法糖:
#![allow(unused)] fn main() { pub fn notify<T: Summary>(item: &T) { println!("Breaking news! {}", item.summarize()); } }
真正的完整书写形式如上所述,形如 T: Summary 被称为特征约束。
在简单的场景下 impl Trait 这种语法糖就足够使用,但是对于复杂的场景,特征约束可以让我们拥有更大的灵活性和语法表现能力,例如一个函数接受两个 impl Summary 的参数:
#![allow(unused)] fn main() { pub fn notify(item1: &impl Summary, item2: &impl Summary) {} }
如果函数两个参数是不同的类型,那么上面的方法很好,只要这两个类型都实现了 Summary 特征即可。但是如果我们想要强制函数的两个参数是同一类型呢?上面的语法就无法做到这种限制,此时我们只能使特征约束来实现:
#![allow(unused)] fn main() { pub fn notify<T: Summary>(item1: &T, item2: &T) {} }
泛型类型 T 说明了 item1 和 item2 必须拥有同样的类型,同时 T: Summary 说明了 T 必须实现 Summary 特征。
多重约束
除了单个约束条件,我们还可以指定多个约束条件,例如除了让参数实现 Summary 特征外,还可以让参数实现 Display 特征以控制它的格式化输出:
#![allow(unused)] fn main() { pub fn notify(item: &(impl Summary + Display)) {} }
除了上述的语法糖形式,还能使用特征约束的形式:
#![allow(unused)] fn main() { pub fn notify<T: Summary + Display>(item: &T) {} }
通过这两个特征,就可以使用 item.summarize 方法,以及通过 println!("{}", item) 来格式化输出 item。
Where 约束
当特征约束变得很多时,函数的签名将变得很复杂:
#![allow(unused)] fn main() { fn some_function<T: Display + Clone, U: Clone + Debug>(t: &T, u: &U) -> i32 {} }
严格来说,上面的例子还是不够复杂,但是我们还是能对其做一些形式上的改进,通过 where:
#![allow(unused)] fn main() { fn some_function<T, U>(t: &T, u: &U) -> i32 where T: Display + Clone, U: Clone + Debug {} }
使用特征约束有条件地实现方法或特征
特征约束,可以让我们在指定类型 + 指定特征的条件下去实现方法,例如:
#![allow(unused)] fn main() { use std::fmt::Display; struct Pair<T> { x: T, y: T, } impl<T> Pair<T> { fn new(x: T, y: T) -> Self { Self { x, y, } } } impl<T: Display + PartialOrd> Pair<T> { fn cmp_display(&self) { if self.x >= self.y { println!("The largest member is x = {}", self.x); } else { println!("The largest member is y = {}", self.y); } } } }
cmp_display 方法,并不是所有的 Pair<T> 结构体对象都可以拥有,只有 T 同时实现了 Display + PartialOrd 的 Pair<T> 才可以拥有此方法。
该函数可读性会更好,因为泛型参数、参数、返回值都在一起,可以快速的阅读,同时每个泛型参数的特征也在新的代码行中通过特征约束进行了约束。
也可以有条件地实现特征, 例如,标准库为任何实现了 Display 特征的类型实现了 ToString 特征:
#![allow(unused)] fn main() { impl<T: Display> ToString for T { // --snip-- } }
我们可以对任何实现了 Display 特征的类型调用由 ToString 定义的 to_string 方法。例如,可以将整型转换为对应的 String 值,因为整型实现了 Display:
#![allow(unused)] fn main() { let s = 3.to_string(); }
函数返回中的 impl Trait
可以通过 impl Trait 来说明一个函数返回了一个类型,该类型实现了某个特征:
#![allow(unused)] fn main() { fn returns_summarizable() -> impl Summary { Weibo { username: String::from("sunface"), content: String::from( "m1 max太厉害了,电脑再也不会卡", ) } } }
因为 Weibo 实现了 Summary,因此这里可以用它来作为返回值。要注意的是,虽然我们知道这里是一个 Weibo 类型,但是对于 returns_summarizable 的调用者而言,他只知道返回了一个实现了 Summary 特征的对象,但是并不知道返回了一个 Weibo 类型。
这种 impl Trait 形式的返回值,在一种场景下非常非常有用,那就是返回的真实类型非常复杂,你不知道该怎么声明时(毕竟 Rust 要求你必须标出所有的类型),此时就可以用 impl Trait 的方式简单返回。例如,闭包和迭代器就是很复杂,只有编译器才知道那玩意的真实类型,如果让你写出来它们的具体类型,估计内心有一万只草泥马奔腾,好在你可以用 impl Iterator 来告诉调用者,返回了一个迭代器,因为所有迭代器都会实现 Iterator 特征。
但是这种返回值方式有一个很大的限制:只能有一个具体的类型,例如:
#![allow(unused)] fn main() { fn returns_summarizable(switch: bool) -> impl Summary { if switch { Post { title: String::from( "Penguins win the Stanley Cup Championship!", ), author: String::from("Iceburgh"), content: String::from( "The Pittsburgh Penguins once again are the best \ hockey team in the NHL.", ), } } else { Weibo { username: String::from("horse_ebooks"), content: String::from( "of course, as you probably already know, people", ), } } } }
以上的代码就无法通过编译,因为它返回了两个不同的类型 Post 和 Weibo。
`if` and `else` have incompatible types
expected struct `Post`, found struct `Weibo`
报错提示我们 if 和 else 返回了不同的类型。如果想要实现返回不同的类型,需要使用下一章节中的特征对象。
修复上一节中的 largest 函数
还记得上一节中的例子吧,当时留下一个疑问,该如何解决编译报错:
#![allow(unused)] fn main() { error[E0369]: binary operation `>` cannot be applied to type `T` // 无法在 `T` 类型上应用`>`运算符 --> src/main.rs:5:17 | 5 | if item > largest { | ---- ^ ------- T | | | T | help: consider restricting type parameter `T` // 考虑使用以下的特征来约束 `T` | 1 | fn largest<T: std::cmp::PartialOrd>(list: &[T]) -> T { | ^^^^^^^^^^^^^^^^^^^^^^ }
在 largest 函数体中我们想要使用大于运算符(>)比较两个 T 类型的值。这个运算符是标准库中特征 std::cmp::PartialOrd 的一个默认方法。所以需要在 T 的特征约束中指定 PartialOrd,这样 largest 函数可以用于内部元素类型可比较大小的数组切片。
由于 PartialOrd 位于 prelude 中所以并不需要通过 std::cmp 手动将其引入作用域。所以可以将 largest 的签名修改为如下:
#![allow(unused)] fn main() { fn largest<T: PartialOrd>(list: &[T]) -> T {} }
但是此时编译,又会出现新的错误:
#![allow(unused)] fn main() { error[E0508]: cannot move out of type `[T]`, a non-copy slice --> src/main.rs:2:23 | 2 | let mut largest = list[0]; | ^^^^^^^ | | | cannot move out of here | help: consider using a reference instead: `&list[0]` error[E0507]: cannot move out of borrowed content --> src/main.rs:4:9 | 4 | for &item in list.iter() { | ^---- | || | |hint: to prevent move, use `ref item` or `ref mut item` | cannot move out of borrowed content }
错误的核心是 cannot move out of type [T], a non-copy slice,原因是 T 没有实现 Copy 特性,因此我们只能把所有权进行转移,毕竟只有 i32 等基础类型才实现了 Copy 特性,可以存储在栈上,而 T 可以指代任何类型(严格来说是实现了 PartialOrd 特征的所有类型)。
因此,为了让 T 拥有 Copy 特性,我们可以增加特征约束:
fn largest<T: PartialOrd + Copy>(list: &[T]) -> T { let mut largest = list[0]; for &item in list.iter() { if item > largest { largest = item; } } largest } fn main() { let number_list = vec![34, 50, 25, 100, 65]; let result = largest(&number_list); println!("The largest number is {}", result); let char_list = vec!['y', 'm', 'a', 'q']; let result = largest(&char_list); println!("The largest char is {}", result); }
如果并不希望限制 largest 函数只能用于实现了 Copy 特征的类型,我们可以在 T 的特征约束中指定 Clone 特征 而不是 Copy 特征。并克隆 list 中的每一个值使得 largest 函数拥有其所有权。使用 clone 函数意味着对于类似 String 这样拥有堆上数据的类型,会潜在地分配更多堆上空间,而堆分配在涉及大量数据时可能会相当缓慢。
另一种 largest 的实现方式是返回在 list 中 T 值的引用。如果我们将函数返回值从 T 改为 &T 并改变函数体使其能够返回一个引用,我们将不需要任何 Clone 或 Copy 的特征约束而且也不会有任何的堆分配。尝试自己实现这种替代解决方式吧!
通过 derive 派生特征
在本书中,形如 #[derive(Debug)] 的代码已经出现了很多次,这种是一种特征派生语法,被 derive 标记的对象会自动实现对应的默认特征代码,继承相应的功能。
例如 Debug 特征,它有一套自动实现的默认代码,当你给一个结构体标记后,就可以使用 println!("{:?}", s) 的形式打印该结构体的对象。
再如 Copy 特征,它也有一套自动实现的默认代码,当标记到一个类型上时,可以让这个类型自动实现 Copy 特征,进而可以调用 copy 方法,进行自我复制。
总之,derive 派生出来的是 Rust 默认给我们提供的特征,在开发过程中极大的简化了自己手动实现相应特征的需求,当然,如果你有特殊的需求,还可以自己手动重载该实现。
详细的 derive 列表参见附录-派生特征。
调用方法需要引入特征
在一些场景中,使用 as 关键字做类型转换会有比较大的限制,因为你想要在类型转换上拥有完全的控制,例如处理转换错误,那么你将需要 TryInto:
use std::convert::TryInto; fn main() { let a: i32 = 10; let b: u16 = 100; let b_ = b.try_into() .unwrap(); if a < b_ { println!("Ten is less than one hundred."); } }
上面代码中引入了 std::convert::TryInto 特征,但是却没有使用它,可能有些同学会为此困惑,主要原因在于如果你要使用一个特征的方法,那么你需要将该特征引入当前的作用域中,我们在上面用到了 try_into 方法,因此需要引入对应的特征。
但是 Rust 又提供了一个非常便利的办法,即把最常用的标准库中的特征通过 std::prelude 模块提前引入到当前作用域中,其中包括了 std::convert::TryInto,你可以尝试删除第一行的代码 use ...,看看是否会报错。
几个综合例子
为自定义类型实现 + 操作
在 Rust 中除了数值类型的加法,String 也可以做加法,因为 Rust 为该类型实现了 std::ops::Add 特征,同理,如果我们为自定义类型实现了该特征,那就可以自己实现 Point1 + Point2 的操作:
use std::ops::Add; // 为Point结构体派生Debug特征,用于格式化输出 #[derive(Debug)] struct Point<T: Add<T, Output = T>> { //限制类型T必须实现了Add特征,否则无法进行+操作。 x: T, y: T, } impl<T: Add<T, Output = T>> Add for Point<T> { type Output = Point<T>; fn add(self, p: Point<T>) -> Point<T> { Point{ x: self.x + p.x, y: self.y + p.y, } } } fn add<T: Add<T, Output=T>>(a:T, b:T) -> T { a + b } fn main() { let p1 = Point{x: 1.1f32, y: 1.1f32}; let p2 = Point{x: 2.1f32, y: 2.1f32}; println!("{:?}", add(p1, p2)); let p3 = Point{x: 1i32, y: 1i32}; let p4 = Point{x: 2i32, y: 2i32}; println!("{:?}", add(p3, p4)); }
自定义类型的打印输出
在开发过程中,往往只要使用 #[derive(Debug)] 对我们的自定义类型进行标注,即可实现打印输出的功能:
#[derive(Debug)] struct Point{ x: i32, y: i32 } fn main() { let p = Point{x:3,y:3}; println!("{:?}",p); }
但是在实际项目中,往往需要对我们的自定义类型进行自定义的格式化输出,以让用户更好的阅读理解我们的类型,此时就要为自定义类型实现 std::fmt::Display 特征:
#![allow(dead_code)] use std::fmt; use std::fmt::{Display}; #[derive(Debug,PartialEq)] enum FileState { Open, Closed, } #[derive(Debug)] struct File { name: String, data: Vec<u8>, state: FileState, } impl Display for FileState { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { FileState::Open => write!(f, "OPEN"), FileState::Closed => write!(f, "CLOSED"), } } } impl Display for File { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "<{} ({})>", self.name, self.state) } } impl File { fn new(name: &str) -> File { File { name: String::from(name), data: Vec::new(), state: FileState::Closed, } } } fn main() { let f6 = File::new("f6.txt"); //... println!("{:?}", f6); println!("{}", f6); }
以上两个例子较为复杂,目的是为读者展示下真实的使用场景长什么样,因此需要读者细细阅读,最终消化这些知识对于你的 Rust 之路会有莫大的帮助。
特征对象
在上一节中有一段代码无法通过编译:
#![allow(unused)] fn main() { fn returns_summarizable(switch: bool) -> impl Summary { if switch { Post { // ... } } else { Weibo { // ... } } } }
其中 Post 和 Weibo 都实现了 Summary 特征,因此上面的函数试图通过返回 impl Summary 来返回这两个类型,但是编译器却无情地报错了,原因是 impl Trait 的返回值类型并不支持多种不同的类型返回,那如果我们想返回多种类型,该怎么办?
再来考虑一个问题:现在在做一款游戏,需要将多个对象渲染在屏幕上,这些对象属于不同的类型,存储在列表中,渲染的时候,需要循环该列表并顺序渲染每个对象,在 Rust 中该怎么实现?
聪明的同学可能已经能想到一个办法,利用枚举:
#[derive(Debug)] enum UiObject { Button, SelectBox, } fn main() { let objects = [ UiObject::Button, UiObject::SelectBox ]; for o in objects { draw(o) } } fn draw(o: UiObject) { println!("{:?}",o); }
Bingo,这个确实是一个办法,但是问题来了,如果你的对象集合并不能事先明确地知道呢?或者别人想要实现一个 UI 组件呢?此时枚举中的类型是有些缺少的,是不是还要修改你的代码增加一个枚举成员?
总之,在编写这个 UI 库时,我们无法知道所有的 UI 对象类型,只知道的是:
- UI 对象的类型不同
- 需要一个统一的类型来处理这些对象,无论是作为函数参数还是作为列表中的一员
- 需要对每一个对象调用
draw方法
在拥有继承的语言中,可以定义一个名为 Component 的类,该类上有一个 draw 方法。其他的类比如 Button、Image 和 SelectBox 会从 Component 派生并因此继承 draw 方法。它们各自都可以覆盖 draw 方法来定义自己的行为,但是框架会把所有这些类型当作是 Component 的实例,并在其上调用 draw。不过 Rust 并没有继承,我们得另寻出路。
特征对象定义
为了解决上面的所有问题,Rust 引入了一个概念 —— 特征对象。
在介绍特征对象之前,先来为之前的 UI 组件定义一个特征:
#![allow(unused)] fn main() { pub trait Draw { fn draw(&self); } }
只要组件实现了 Draw 特征,就可以调用 draw 方法来进行渲染。假设有一个 Button 和 SelectBox 组件实现了 Draw 特征:
#![allow(unused)] fn main() { pub struct Button { pub width: u32, pub height: u32, pub label: String, } impl Draw for Button { fn draw(&self) { // 绘制按钮的代码 } } struct SelectBox { width: u32, height: u32, options: Vec<String>, } impl Draw for SelectBox { fn draw(&self) { // 绘制SelectBox的代码 } } }
此时,还需要一个动态数组来存储这些 UI 对象:
#![allow(unused)] fn main() { pub struct Screen { pub components: Vec<?>, } }
注意到上面代码中的 ? 吗?它的意思是:我们应该填入什么类型,可以说就之前学过的内容里,你找不到哪个类型可以填入这里,但是因为 Button 和 SelectBox 都实现了 Draw 特征,那我们是不是可以把 Draw 特征的对象作为类型,填入到数组中呢?答案是肯定的。
特征对象指向实现了 Draw 特征的类型的实例,也就是指向了 Button 或者 SelectBox 的实例,这种映射关系是存储在一张表中,可以在运行时通过特征对象找到具体调用的类型方法。
可以通过 & 引用或者 Box<T> 智能指针的方式来创建特征对象。
Box<T>在后面章节会详细讲解,大家现在把它当成一个引用即可,只不过它包裹的值会被强制分配在堆上
trait Draw { fn draw(&self) -> String; } impl Draw for u8 { fn draw(&self) -> String { format!("u8: {}", *self) } } impl Draw for f64 { fn draw(&self) -> String { format!("f64: {}", *self) } } // 若 T 实现了 Draw 特征, 则调用该函数时传入的 Box<T> 可以被隐式转换成函数参数签名中的 Box<dyn Draw> fn draw1(x: Box<dyn Draw>) { // 由于实现了 Deref 特征,Box 智能指针会自动解引用为它所包裹的值,然后调用该值对应的类型上定义的 `draw` 方法 x.draw(); } fn draw2(x: &dyn Draw) { x.draw(); } fn main() { let x = 1.1f64; // do_something(&x); let y = 8u8; // x 和 y 的类型 T 都实现了 `Draw` 特征,因为 Box<T> 可以在函数调用时隐式地被转换为特征对象 Box<dyn Draw> // 基于 x 的值创建一个 Box<f64> 类型的智能指针,指针指向的数据被放置在了堆上 draw1(Box::new(x)); // 基于 y 的值创建一个 Box<u8> 类型的智能指针 draw1(Box::new(y)); draw2(&x); draw2(&y); }
上面代码,有几个非常重要的点:
draw1函数的参数是Box<dyn Draw>形式的特征对象,该特征对象是通过Box::new(x)的方式创建的draw2函数的参数是&dyn Draw形式的特征对象,该特征对象是通过&x的方式创建的dyn关键字只用在特征对象的类型声明上,在创建时无需使用dyn
因此,可以使用特征对象来代表泛型或具体的类型。
继续来完善之前的 UI 组件代码,首先来实现 Screen:
#![allow(unused)] fn main() { pub struct Screen { pub components: Vec<Box<dyn Draw>>, } }
其中存储了一个动态数组,里面元素的类型是 Draw 特征对象:Box<dyn Draw>,任何实现了 Draw 特征的类型,都可以存放其中。
再来为 Screen 定义 run 方法,用于将列表中的 UI 组件渲染在屏幕上:
#![allow(unused)] fn main() { impl Screen { pub fn run(&self) { for component in self.components.iter() { component.draw(); } } } }
至此,我们就完成了之前的目标:在列表中存储多种不同类型的实例,然后将它们使用同一个方法逐一渲染在屏幕上!
再来看看,如果通过泛型实现,会如何:
#![allow(unused)] fn main() { pub struct Screen<T: Draw> { pub components: Vec<T>, } impl<T> Screen<T> where T: Draw { pub fn run(&self) { for component in self.components.iter() { component.draw(); } } } }
上面的 Screen 的列表中,存储了类型为 T 的元素,然后在 Screen 中使用特征约束让 T 实现了 Draw 特征,进而可以调用 draw 方法。
但是这种写法限制了 Screen 实例的 Vec<T> 中的每个元素必须是 Button 类型或者全是 SelectBox 类型。如果只需要同质(相同类型)集合,更倾向于这种写法:使用泛型和 特征约束,因为实现更清晰,且性能更好(特征对象,需要在运行时从 vtable 动态查找需要调用的方法)。
现在来运行渲染下咱们精心设计的 UI 组件列表:
fn main() { let screen = Screen { components: vec![ Box::new(SelectBox { width: 75, height: 10, options: vec![ String::from("Yes"), String::from("Maybe"), String::from("No") ], }), Box::new(Button { width: 50, height: 10, label: String::from("OK"), }), ], }; screen.run(); }
上面使用 Box::new(T) 的方式来创建了两个 Box<dyn Draw> 特征对象,如果以后还需要增加一个 UI 组件,那么让该组件实现 Draw 特征,则可以很轻松的将其渲染在屏幕上,甚至用户可以引入我们的库作为三方库,然后在自己的库中为自己的类型实现 Draw 特征,然后进行渲染。
在动态类型语言中,有一个很重要的概念:鸭子类型(duck typing),简单来说,就是只关心值长啥样,而不关心它实际是什么。当一个东西走起来像鸭子,叫起来像鸭子,那么它就是一只鸭子,就算它实际上是一个奥特曼,也不重要,我们就当它是鸭子。
在上例中,Screen 在 run 的时候,我们并不需要知道各个组件的具体类型是什么。它也不检查组件到底是 Button 还是 SelectBox 的实例,只要它实现了 Draw 特征,就能通过 Box::new 包装成 Box<dyn Draw> 特征对象,然后被渲染在屏幕上。
使用特征对象和 Rust 类型系统来进行类似鸭子类型操作的优势是,无需在运行时检查一个值是否实现了特定方法或者担心在调用时因为值没有实现方法而产生错误。如果值没有实现特征对象所需的特征, 那么 Rust 根本就不会编译这些代码:
fn main() { let screen = Screen { components: vec![ Box::new(String::from("Hi")), ], }; screen.run(); }
因为 String 类型没有实现 Draw 特征,编译器直接就会报错,不会让上述代码运行。如果想要 String 类型被渲染在屏幕上,那么只需要为其实现 Draw 特征即可,非常容易。
注意 dyn 不能单独作为特征对象的定义,例如下面的代码编译器会报错,原因是特征对象可以是任意实现了某个特征的类型,编译器在编译期不知道该类型的大小,不同的类型大小是不同的。
而 &dyn 和 Box<dyn> 在编译期都是已知大小,所以可以用作特征对象的定义。
#![allow(unused)] fn main() { fn draw2(x: dyn Draw) { x.draw(); } }
10 | fn draw2(x: dyn Draw) {
| ^ doesn't have a size known at compile-time
|
= help: the trait `Sized` is not implemented for `(dyn Draw + 'static)`
help: function arguments must have a statically known size, borrowed types always have a known size
特征对象的动态分发
回忆一下泛型章节我们提到过的,泛型是在编译期完成处理的:编译器会为每一个泛型参数对应的具体类型生成一份代码,这种方式是静态分发(static dispatch),因为是在编译期完成的,对于运行期性能完全没有任何影响。
与静态分发相对应的是动态分发(dynamic dispatch),在这种情况下,直到运行时,才能确定需要调用什么方法。之前代码中的关键字 dyn 正是在强调这一“动态”的特点。
当使用特征对象时,Rust 必须使用动态分发。编译器无法知晓所有可能用于特征对象代码的类型,所以它也不知道应该调用哪个类型的哪个方法实现。为此,Rust 在运行时使用特征对象中的指针来知晓需要调用哪个方法。动态分发也阻止编译器有选择的内联方法代码,这会相应的禁用一些优化。
下面这张图很好的解释了静态分发 Box<T> 和动态分发 Box<dyn Trait> 的区别:
结合上文的内容和这张图可以了解:
- 特征对象大小不固定:这是因为,对于特征
Draw,类型Button可以实现特征Draw,类型SelectBox也可以实现特征Draw,因此特征没有固定大小 - 几乎总是使用特征对象的引用方式,如
&dyn Draw、Box<dyn Draw>- 虽然特征对象没有固定大小,但它的引用类型的大小是固定的,它由两个指针组成(
ptr和vptr),因此占用两个指针大小 - 一个指针
ptr指向实现了特征Draw的具体类型的实例,也就是当作特征Draw来用的类型的实例,比如类型Button的实例、类型SelectBox的实例 - 另一个指针
vptr指向一个虚表vtable,vtable中保存了类型Button或类型SelectBox的实例对于可以调用的实现于特征Draw的方法。当调用方法时,直接从vtable中找到方法并调用。之所以要使用一个vtable来保存各实例的方法,是因为实现了特征Draw的类型有多种,这些类型拥有的方法各不相同,当将这些类型的实例都当作特征Draw来使用时(此时,它们全都看作是特征Draw类型的实例),有必要区分这些实例各自有哪些方法可调用
- 虽然特征对象没有固定大小,但它的引用类型的大小是固定的,它由两个指针组成(
简而言之,当类型 Button 实现了特征 Draw 时,类型 Button 的实例对象 btn 可以当作特征 Draw 的特征对象类型来使用,btn 中保存了作为特征对象的数据指针(指向类型 Button 的实例数据)和行为指针(指向 vtable)。
一定要注意,此时的 btn 是 Draw 的特征对象的实例,而不再是具体类型 Button 的实例,而且 btn 的 vtable 只包含了实现自特征 Draw 的那些方法(比如 draw),因此 btn 只能调用实现于特征 Draw 的 draw 方法,而不能调用类型 Button 本身实现的方法和类型 Button 实现于其他特征的方法。也就是说,btn 是哪个特征对象的实例,它的 vtable 中就包含了该特征的方法。
Self 与 self
在 Rust 中,有两个self,一个指代当前的实例对象,一个指代特征或者方法类型的别名:
trait Draw { fn draw(&self) -> Self; } #[derive(Clone)] struct Button; impl Draw for Button { fn draw(&self) -> Self { return self.clone() } } fn main() { let button = Button; let newb = button.draw(); }
上述代码中,self指代的就是当前的实例对象,也就是 button.draw() 中的 button 实例,Self 则指代的是 Button 类型。
当理解了 self 与 Self 的区别后,我们再来看看何为对象安全。
特征对象的限制
不是所有特征都能拥有特征对象,只有对象安全的特征才行。当一个特征的所有方法都有如下属性时,它的对象才是安全的:
- 方法的返回类型不能是
Self - 方法没有任何泛型参数
对象安全对于特征对象是必须的,因为一旦有了特征对象,就不再需要知道实现该特征的具体类型是什么了。如果特征方法返回了具体的 Self 类型,但是特征对象忘记了其真正的类型,那这个 Self 就非常尴尬,因为没人知道它是谁了。但是对于泛型类型参数来说,当使用特征时其会放入具体的类型参数:此具体类型变成了实现该特征的类型的一部分。而当使用特征对象时其具体类型被抹去了,故而无从得知放入泛型参数类型到底是什么。
标准库中的 Clone 特征就不符合对象安全的要求:
#![allow(unused)] fn main() { pub trait Clone { fn clone(&self) -> Self; } }
因为它的其中一个方法,返回了 Self 类型,因此它是对象不安全的。
String 类型实现了 Clone 特征, String 实例上调用 clone 方法时会得到一个 String 实例。类似的,当调用 Vec<T> 实例的 clone 方法会得到一个 Vec<T> 实例。clone 的签名需要知道什么类型会代替 Self,因为这是它的返回值。
如果违反了对象安全的规则,编译器会提示你。例如,如果尝试使用之前的 Screen 结构体来存放实现了 Clone 特征的类型:
#![allow(unused)] fn main() { pub struct Screen { pub components: Vec<Box<dyn Clone>>, } }
将会得到如下错误:
error[E0038]: the trait `std::clone::Clone` cannot be made into an object
--> src/lib.rs:2:5
|
2 | pub components: Vec<Box<dyn Clone>>,
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `std::clone::Clone`
cannot be made into an object
|
= note: the trait cannot require that `Self : Sized`
这意味着不能以这种方式使用此特征作为特征对象。
深入了解特征
特征之于 Rust 更甚于接口之于其他语言,因此特征在 Rust 中很重要也相对较为复杂,我们决定把特征分为两篇进行介绍,第一篇在之前已经讲过,现在就是第二篇:关于特征的进阶篇,会讲述一些不常用到但是你该了解的特性。
关联类型
在方法一章中,我们讲到了关联函数,但是实际上关联类型和关联函数并没有任何交集,虽然它们的名字有一半的交集。
关联类型是在特征定义的语句块中,申明一个自定义类型,这样就可以在特征的方法签名中使用该类型:
#![allow(unused)] fn main() { pub trait Iterator { type Item; fn next(&mut self) -> Option<Self::Item>; } }
以上是标准库中的迭代器特征 Iterator,它有一个 Item 关联类型,用于替代遍历的值的类型。
同时,next 方法也返回了一个 Item 类型,不过使用 Option 枚举进行了包裹,假如迭代器中的值是 i32 类型,那么调用 next 方法就将获取一个 Option<i32> 的值。
还记得 Self 吧?在之前的章节提到过, Self 用来指代当前调用者的具体类型,那么 Self::Item 就用来指代该类型实现中定义的 Item 类型:
impl Iterator for Counter { type Item = u32; fn next(&mut self) -> Option<Self::Item> { // --snip-- } } fn main() { let c = Counter{..} c.next() }
在上述代码中,我们为 Counter 类型实现了 Iterator 特征,变量 c 是特征 Iterator 的实例,也是 next 方法的调用者。 结合之前的黑体内容可以得出:对于 next 方法而言,Self 是调用者 c 的具体类型: Counter,而 Self::Item 是 Counter 中定义的 Item 类型: u32。
聪明的读者之所以聪明,是因为你们喜欢联想和举一反三,同时你们也喜欢提问:为何不用泛型,例如如下代码:
#![allow(unused)] fn main() { pub trait Iterator<Item> { fn next(&mut self) -> Option<Item>; } }
答案其实很简单,为了代码的可读性,当你使用了泛型后,你需要在所有地方都写 Iterator<Item>,而使用了关联类型,你只需要写 Iterator,当类型定义复杂时,这种写法可以极大的增加可读性:
#![allow(unused)] fn main() { pub trait CacheableItem: Clone + Default + fmt::Debug + Decodable + Encodable { type Address: AsRef<[u8]> + Clone + fmt::Debug + Eq + Hash; fn is_null(&self) -> bool; } }
例如上面的代码,Address 的写法自然远比 AsRef<[u8]> + Clone + fmt::Debug + Eq + Hash 要简单的多,而且含义清晰。
再例如,如果使用泛型,你将得到以下的代码:
#![allow(unused)] fn main() { trait Container<A,B> { fn contains(&self,a: A,b: B) -> bool; } fn difference<A,B,C>(container: &C) -> i32 where C : Container<A,B> {...} }
可以看到,由于使用了泛型,导致函数头部也必须增加泛型的声明,而使用关联类型,将得到可读性好得多的代码:
#![allow(unused)] fn main() { trait Container{ type A; type B; fn contains(&self, a: &Self::A, b: &Self::B) -> bool; } fn difference<C: Container>(container: &C) {} }
默认泛型类型参数
当使用泛型类型参数时,可以为其指定一个默认的具体类型,例如标准库中的 std::ops::Add 特征:
#![allow(unused)] fn main() { trait Add<RHS=Self> { type Output; fn add(self, rhs: RHS) -> Self::Output; } }
它有一个泛型参数 RHS,但是与我们以往的用法不同,这里它给 RHS 一个默认值,也就是当用户不指定 RHS 时,默认使用两个同样类型的值进行相加,然后返回一个关联类型 Output。
可能上面那段不太好理解,下面我们用代码来举例:
use std::ops::Add; #[derive(Debug, PartialEq)] struct Point { x: i32, y: i32, } impl Add for Point { type Output = Point; fn add(self, other: Point) -> Point { Point { x: self.x + other.x, y: self.y + other.y, } } } fn main() { assert_eq!(Point { x: 1, y: 0 } + Point { x: 2, y: 3 }, Point { x: 3, y: 3 }); }
上面的代码主要干了一件事,就是为 Point 结构体提供 + 的能力,这就是运算符重载,不过 Rust 并不支持创建自定义运算符,你也无法为所有运算符进行重载,目前来说,只有定义在 std::ops 中的运算符才能进行重载。
跟 + 对应的特征是 std::ops::Add,我们在之前也看过它的定义 trait Add<RHS=Self>,但是上面的例子中并没有为 Point 实现 Add<RHS> 特征,而是实现了 Add 特征(没有默认泛型类型参数),这意味着我们使用了 RHS 的默认类型,也就是 Self。换句话说,我们这里定义的是两个相同的 Point 类型相加,因此无需指定 RHS。
与上面的例子相反,下面的例子,我们来创建两个不同类型的相加:
#![allow(unused)] fn main() { use std::ops::Add; struct Millimeters(u32); struct Meters(u32); impl Add<Meters> for Millimeters { type Output = Millimeters; fn add(self, other: Meters) -> Millimeters { Millimeters(self.0 + (other.0 * 1000)) } } }
这里,是进行 Millimeters + Meters 两种数据类型的 + 操作,因此此时不能再使用默认的 RHS,否则就会变成 Millimeters + Millimeters 的形式。使用 Add<Meters> 可以将 RHS 指定为 Meters,那么 fn add(self, rhs: RHS) 自然而言的变成了 Millimeters 和 Meters 的相加。
默认类型参数主要用于两个方面:
- 减少实现的样板代码
- 扩展类型但是无需大幅修改现有的代码
之前的例子就是第一点,虽然效果也就那样。在 + 左右两边都是同样类型时,只需要 impl Add 即可,否则你需要 impl Add<SOME_TYPE>,嗯,会多写几个字:)
对于第二点,也很好理解,如果你在一个复杂类型的基础上,新引入一个泛型参数,可能需要修改很多地方,但是如果新引入的泛型参数有了默认类型,情况就会好很多,添加泛型参数后,使用这个类型的代码需要逐个在类型提示部分添加泛型参数,就很麻烦;但是有了默认参数(且默认参数取之前的实现里假设的值的情况下)之后,原有的使用这个类型的代码就不需要做改动了。
归根到底,默认泛型参数,是有用的,但是大多数情况下,咱们确实用不到,当需要用到时,大家再回头来查阅本章即可,手上有剑,心中不慌。
调用同名的方法
不同特征拥有同名的方法是很正常的事情,你没有任何办法阻止这一点;甚至除了特征上的同名方法外,在你的类型上,也有同名方法:
#![allow(unused)] fn main() { trait Pilot { fn fly(&self); } trait Wizard { fn fly(&self); } struct Human; impl Pilot for Human { fn fly(&self) { println!("This is your captain speaking."); } } impl Wizard for Human { fn fly(&self) { println!("Up!"); } } impl Human { fn fly(&self) { println!("*waving arms furiously*"); } } }
这里,不仅仅两个特征 Pilot 和 Wizard 有 fly 方法,就连实现那两个特征的 Human 单元结构体,也拥有一个同名方法 fly (这世界怎么了,非要这么卷吗?程序员何苦难为程序员,哎)。
既然代码已经不可更改,那下面我们来讲讲该如何调用这些 fly 方法。
优先调用类型上的方法
当调用 Human 实例的 fly 时,编译器默认调用该类型中定义的方法:
fn main() { let person = Human; person.fly(); }
这段代码会打印 *waving arms furiously*,说明直接调用了类型上定义的方法。
调用特征上的方法
为了能够调用两个特征的方法,需要使用显式调用的语法:
fn main() { let person = Human; Pilot::fly(&person); // 调用Pilot特征上的方法 Wizard::fly(&person); // 调用Wizard特征上的方法 person.fly(); // 调用Human类型自身的方法 }
运行后依次输出:
This is your captain speaking.
Up!
*waving arms furiously*
因为 fly 方法的参数是 self,当显式调用时,编译器就可以根据调用的类型( self 的类型)决定具体调用哪个方法。
这个时候问题又来了,如果方法没有 self 参数呢?稍等,估计有读者会问:还有方法没有 self 参数?看到这个疑问,作者的眼泪不禁流了下来,大明湖畔的关联函数,你还记得嘛?
但是成年人的世界,就算再伤心,事还得做,咱们继续:
trait Animal { fn baby_name() -> String; } struct Dog; impl Dog { fn baby_name() -> String { String::from("Spot") } } impl Animal for Dog { fn baby_name() -> String { String::from("puppy") } } fn main() { println!("A baby dog is called a {}", Dog::baby_name()); }
就像人类妈妈会给自己的宝宝起爱称一样,狗妈妈也会。狗妈妈称呼自己的宝宝为Spot,其它动物称呼狗宝宝为puppy,这个时候假如有动物不知道该如何称呼狗宝宝,它需要查询一下。
Dog::baby_name() 的调用方式显然不行,因为这只是狗妈妈对宝宝的爱称,可能你会想到通过下面的方式查询其他动物对狗狗的称呼:
fn main() { println!("A baby dog is called a {}", Animal::baby_name()); }
铛铛,无情报错了:
#![allow(unused)] fn main() { error[E0283]: type annotations needed // 需要类型注释 --> src/main.rs:20:43 | 20 | println!("A baby dog is called a {}", Animal::baby_name()); | ^^^^^^^^^^^^^^^^^ cannot infer type // 无法推断类型 | = note: cannot satisfy `_: Animal` }
因为单纯从 Animal::baby_name() 上,编译器无法得到任何有效的信息:实现 Animal 特征的类型可能有很多,你究竟是想获取哪个动物宝宝的名称?狗宝宝?猪宝宝?还是熊宝宝?
此时,就需要使用完全限定语法。
完全限定语法
完全限定语法是调用函数最为明确的方式:
fn main() { println!("A baby dog is called a {}", <Dog as Animal>::baby_name()); }
在尖括号中,通过 as 关键字,我们向 Rust 编译器提供了类型注解,也就是 Animal 就是 Dog,而不是其他动物,因此最终会调用 impl Animal for Dog 中的方法,获取到其它动物对狗宝宝的称呼:puppy。
言归正题,完全限定语法定义为:
#![allow(unused)] fn main() { <Type as Trait>::function(receiver_if_method, next_arg, ...); }
上面定义中,第一个参数是方法接收器 receiver (三种 self),只有方法才拥有,例如关联函数就没有 receiver。
完全限定语法可以用于任何函数或方法调用,那么我们为何很少用到这个语法?原因是 Rust 编译器能根据上下文自动推导出调用的路径,因此大多数时候,我们都无需使用完全限定语法。只有当存在多个同名函数或方法,且 Rust 无法区分出你想调用的目标函数时,该用法才能真正有用武之地。
特征定义中的特征约束
有时,我们会需要让某个特征 A 能使用另一个特征 B 的功能(另一种形式的特征约束),这种情况下,不仅仅要为类型实现特征 A,还要为类型实现特征 B 才行,这就是 supertrait (实在不知道该如何翻译,有大佬指导下嘛?)
例如有一个特征 OutlinePrint,它有一个方法,能够对当前的实现类型进行格式化输出:
#![allow(unused)] fn main() { use std::fmt::Display; trait OutlinePrint: Display { fn outline_print(&self) { let output = self.to_string(); let len = output.len(); println!("{}", "*".repeat(len + 4)); println!("*{}*", " ".repeat(len + 2)); println!("* {} *", output); println!("*{}*", " ".repeat(len + 2)); println!("{}", "*".repeat(len + 4)); } } }
等等,这里有一个眼熟的语法: OutlinePrint: Display,感觉很像之前讲过的特征约束,只不过用在了特征定义中而不是函数的参数中,是的,在某种意义上来说,这和特征约束非常类似,都用来说明一个特征需要实现另一个特征,这里就是:如果你想要实现 OutlinePrint 特征,首先你需要实现 Display 特征。
想象一下,假如没有这个特征约束,那么 self.to_string 还能够调用吗( to_string 方法会为实现 Display 特征的类型自动实现)?编译器肯定是不愿意的,会报错说当前作用域中找不到用于 &Self 类型的方法 to_string :
#![allow(unused)] fn main() { struct Point { x: i32, y: i32, } impl OutlinePrint for Point {} }
因为 Point 没有实现 Display 特征,会得到下面的报错:
error[E0277]: the trait bound `Point: std::fmt::Display` is not satisfied
--> src/main.rs:20:6
|
20 | impl OutlinePrint for Point {}
| ^^^^^^^^^^^^ `Point` cannot be formatted with the default formatter;
try using `:?` instead if you are using a format string
|
= help: the trait `std::fmt::Display` is not implemented for `Point`
既然我们有求于编译器,那只能选择满足它咯:
#![allow(unused)] fn main() { use std::fmt; impl fmt::Display for Point { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "({}, {})", self.x, self.y) } } }
上面代码为 Point 实现了 Display 特征,那么 to_string 方法也将自动实现:最终获得字符串是通过这里的 fmt 方法获得的。
在外部类型上实现外部特征(newtype)
在特征章节中,有提到孤儿规则,简单来说,就是特征或者类型必需至少有一个是本地的,才能在此类型上定义特征。
这里提供一个办法来绕过孤儿规则,那就是使用newtype 模式,简而言之:就是为一个元组结构体创建新类型。该元组结构体封装有一个字段,该字段就是希望实现特征的具体类型。
该封装类型是本地的,因此我们可以为此类型实现外部的特征。
newtype 不仅仅能实现以上的功能,而且它在运行时没有任何性能损耗,因为在编译期,该类型会被自动忽略。
下面来看一个例子,我们有一个动态数组类型: Vec<T>,它定义在标准库中,还有一个特征 Display,它也定义在标准库中,如果没有 newtype,我们是无法为 Vec<T> 实现 Display 的:
error[E0117]: only traits defined in the current crate can be implemented for arbitrary types
--> src/main.rs:5:1
|
5 | impl<T> std::fmt::Display for Vec<T> {
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^------
| | |
| | Vec is not defined in the current crate
| impl doesn't use only types from inside the current crate
|
= note: define and implement a trait or new type instead
编译器给了我们提示: define and implement a trait or new type instead,重新定义一个特征,或者使用 new type,前者当然不可行,那么来试试后者:
use std::fmt; struct Wrapper(Vec<String>); impl fmt::Display for Wrapper { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "[{}]", self.0.join(", ")) } } fn main() { let w = Wrapper(vec![String::from("hello"), String::from("world")]); println!("w = {}", w); }
其中,struct Wrapper(Vec<String>) 就是一个元组结构体,它定义了一个新类型 Wrapper,代码很简单,相信大家也很容易看懂。
既然 new type 有这么多好处,它有没有不好的地方呢?答案是肯定的。注意到我们怎么访问里面的数组吗?self.0.join(", "),是的,很啰嗦,因为需要先从 Wrapper 中取出数组: self.0,然后才能执行 join 方法。
类似的,任何数组上的方法,你都无法直接调用,需要先用 self.0 取出数组,然后再进行调用。
当然,解决办法还是有的,要不怎么说 Rust 是极其强大灵活的编程语言!Rust 提供了一个特征叫 Deref,实现该特征后,可以自动做一层类似类型转换的操作,可以将 Wrapper 变成 Vec<String> 来使用。这样就会像直接使用数组那样去使用 Wrapper,而无需为每一个操作都添加上 self.0。
同时,如果不想 Wrapper 暴露底层数组的所有方法,我们还可以为 Wrapper 去重载这些方法,实现隐藏的目的。
集合类型
动态数组 Vector
动态数组类型用 Vec<T> 表示,事实上,在之前的章节,它的身影多次出现,我们一直没有细讲,只是简单的把它当作数组处理。
动态数组允许你存储多个值,这些值在内存中一个紧挨着另一个排列,因此访问其中某个元素的成本非常低。动态数组只能存储相同类型的元素,如果你想存储不同类型的元素,可以使用之前讲过的枚举类型或者特征对象。
总之,当我们想拥有一个列表,里面都是相同类型的数据时,动态数组将会非常有用。
创建动态数组
在 Rust 中,有多种方式可以创建动态数组。
Vec::new
使用 Vec::new 创建动态数组是最 rusty 的方式,它调用了 Vec 中的 new 关联函数:
#![allow(unused)] fn main() { let v: Vec<i32> = Vec::new(); }
这里,v 被显式地声明了类型 Vec<i32>,这是因为 Rust 编译器无法从 Vec::new() 中得到任何关于类型的暗示信息,因此也无法推导出 v 的具体类型,但是当你向里面增加一个元素后,一切又不同了:
#![allow(unused)] fn main() { let mut v = Vec::new(); v.push(1); }
此时,v 就无需手动声明类型,因为编译器通过 v.push(1),推测出 v 中的元素类型是 i32,因此推导出 v 的类型是 Vec<i32>。
如果预先知道要存储的元素个数,可以使用
Vec::with_capacity(capacity)创建动态数组,这样可以避免因为插入大量新数据导致频繁的内存分配和拷贝,提升性能
vec![]
还可以使用宏 vec! 来创建数组,与 Vec::new 有所不同,前者能在创建同时给予初始化值:
#![allow(unused)] fn main() { let v = vec![1, 2, 3]; }
同样,此处的 v 也无需标注类型,编译器只需检查它内部的元素即可自动推导出 v 的类型是 Vec<i32> (Rust 中,整数默认类型是 i32,在数值类型中有详细介绍)。
更新 Vector
向数组尾部添加元素,可以使用 push 方法:
#![allow(unused)] fn main() { let mut v = Vec::new(); v.push(1); }
与其它类型一样,必须将 v 声明为 mut 后,才能进行修改。
Vector 与其元素共存亡
跟结构体一样,Vector 类型在超出作用域范围后,会被自动删除:
#![allow(unused)] fn main() { { let v = vec![1, 2, 3]; // ... } // <- v超出作用域并在此处被删除 }
当 Vector 被删除后,它内部存储的所有内容也会随之被删除。目前来看,这种解决方案简单直白,但是当 Vector 中的元素被引用后,事情可能会没那么简单。
从 Vector 中读取元素
读取指定位置的元素有两种方式可选:
- 通过下标索引访问。
- 使用
get方法。
#![allow(unused)] fn main() { let v = vec![1, 2, 3, 4, 5]; let third: &i32 = &v[2]; println!("第三个元素是 {}", third); match v.get(2) { Some(third) => println!("第三个元素是 {third}"), None => println!("去你的第三个元素,根本没有!"), } }
和其它语言一样,集合类型的索引下标都是从 0 开始,&v[2] 表示借用 v 中的第三个元素,最终会获得该元素的引用。而 v.get(2) 也是访问第三个元素,但是有所不同的是,它返回了 Option<&T>,因此还需要额外的 match 来匹配解构出具体的值。
细心的同学会注意到这里使用了两种格式化输出的方式,其中第一种我们在之前已经见过,而第二种是后续新版本中引入的写法,也是更推荐的用法,具体介绍请参见格式化输出章节
下标索引与 .get 的区别
这两种方式都能成功的读取到指定的数组元素,既然如此为什么会存在两种方法?何况 .get 还会增加使用复杂度,这就涉及到数组越界的问题了,让我们通过示例说明:
#![allow(unused)] fn main() { let v = vec![1, 2, 3, 4, 5]; let does_not_exist = &v[100]; let does_not_exist = v.get(100); }
运行以上代码,&v[100] 的访问方式会导致程序无情报错退出,因为发生了数组越界访问。 但是 v.get 就不会,它在内部做了处理,有值的时候返回 Some(T),无值的时候返回 None,因此 v.get 的使用方式非常安全。
既然如此,为何不统一使用 v.get 的形式?因为实在是有些啰嗦,Rust 语言的设计者和使用者在审美这方面还是相当统一的:简洁即正义,何况性能上也会有轻微的损耗。
既然有两个选择,肯定就有如何选择的问题,答案很简单,当你确保索引不会越界的时候,就用索引访问,否则用 .get。例如,访问第几个数组元素并不取决于我们,而是取决于用户的输入时,用 .get 会非常适合,天知道那些可爱的用户会输入一个什么样的数字进来!
同时借用多个数组元素
既然涉及到借用数组元素,那么很可能会遇到同时借用多个数组元素的情况,还记得在所有权和借用章节咱们讲过的借用规则嘛?如果记得,就来看看下面的代码 :)
#![allow(unused)] fn main() { let mut v = vec![1, 2, 3, 4, 5]; let first = &v[0]; v.push(6); println!("The first element is: {first}"); }
先不运行,来推断下结果,首先 first = &v[0] 进行了不可变借用,v.push 进行了可变借用,如果 first 在 v.push 之后不再使用,那么该段代码可以成功编译。
可是上面的代码中,first 这个不可变借用在可变借用 v.push 后被使用了,那么妥妥的,编译器就会报错:
$ cargo run
Compiling collections v0.1.0 (file:///projects/collections)
error[E0502]: cannot borrow `v` as mutable because it is also borrowed as immutable 无法对v进行可变借用,因此之前已经进行了不可变借用
--> src/main.rs:6:5
|
4 | let first = &v[0];
| - immutable borrow occurs here // 不可变借用发生在此处
5 |
6 | v.push(6);
| ^^^^^^^^^ mutable borrow occurs here // 可变借用发生在此处
7 |
8 | println!("The first element is: {}", first);
| ----- immutable borrow later used here // 不可变借用在这里被使用
For more information about this error, try `rustc --explain E0502`.
error: could not compile `collections` due to previous error
其实,按理来说,这两个引用不应该互相影响的:一个是查询元素,一个是在数组尾部插入元素,完全不相干的操作,为何编译器要这么严格呢?
原因在于:数组的大小是可变的,当旧数组的大小不够用时,Rust 会重新分配一块更大的内存空间,然后把旧数组拷贝过来。这种情况下,之前的引用显然会指向一块无效的内存,这非常 rusty —— 对用户进行严格的教育。
其实想想,在长大之后,我们感激人生路上遇到过的严师益友,正是因为他们,我们才在正确的道路上不断前行,虽然在那个时候,并不能理解他们,而 Rust 就如那个良师益友,它不断的在纠正我们不好的编程习惯,直到某一天,你发现自己能写出一次性通过的漂亮代码时,就能明白它的良苦用心。
若读者想要更深入的了解
Vec<T>,可以看看Rustonomicon,其中从零手撸一个动态数组,非常适合深入学习
迭代遍历 Vector 中的元素
如果想要依次访问数组中的元素,可以使用迭代的方式去遍历数组,这种方式比用下标的方式去遍历数组更安全也更高效(每次下标访问都会触发数组边界检查):
#![allow(unused)] fn main() { let v = vec![1, 2, 3]; for i in &v { println!("{i}"); } }
也可以在迭代过程中,修改 Vector 中的元素:
#![allow(unused)] fn main() { let mut v = vec![1, 2, 3]; for i in &mut v { *i += 10 } }
存储不同类型的元素
在本节开头,有讲到数组的元素必须类型相同,但是也提到了解决方案:那就是通过使用枚举类型和特征对象来实现不同类型元素的存储。先来看看通过枚举如何实现:
#[derive(Debug)] enum IpAddr { V4(String), V6(String) } fn main() { let v = vec![ IpAddr::V4("127.0.0.1".to_string()), IpAddr::V6("::1".to_string()) ]; for ip in v { show_addr(ip) } } fn show_addr(ip: IpAddr) { println!("{:?}",ip); }
数组 v 中存储了两种不同的 ip 地址,但是这两种都属于 IpAddr 枚举类型的成员,因此可以存储在数组中。
再来看看特征对象的实现:
trait IpAddr { fn display(&self); } struct V4(String); impl IpAddr for V4 { fn display(&self) { println!("ipv4: {:?}",self.0) } } struct V6(String); impl IpAddr for V6 { fn display(&self) { println!("ipv6: {:?}",self.0) } } fn main() { let v: Vec<Box<dyn IpAddr>> = vec![ Box::new(V4("127.0.0.1".to_string())), Box::new(V6("::1".to_string())), ]; for ip in v { ip.display(); } }
比枚举实现要稍微复杂一些,我们为 V4 和 V6 都实现了特征 IpAddr,然后将它俩的实例用 Box::new 包裹后,存在了数组 v 中,需要注意的是,这里必须手动地指定类型:Vec<Box<dyn IpAddr>>,表示数组 v 存储的是特征 IpAddr 的对象,这样就实现了在数组中存储不同的类型。
在实际使用场景中,特征对象数组要比枚举数组常见很多,主要原因在于特征对象非常灵活,而编译器对枚举的限制较多,且无法动态增加类型。
KV 存储 HashMap
和动态数组一样,HashMap 也是 Rust 标准库中提供的集合类型,但是又与动态数组不同,HashMap 中存储的是一一映射的 KV 键值对,并提供了平均复杂度为 O(1) 的查询方法,当我们希望通过一个 Key 去查询值时,该类型非常有用,以致于 Go 语言将该类型设置成了语言级别的内置特性。
Rust 中哈希类型(哈希映射)为 HashMap<K,V>,在其它语言中,也有类似的数据结构,例如 hash map,map,object,hash table,字典 等等,引用小品演员孙涛的一句台词:大家都是本地狐狸,别搁那装貂 :)。
创建 HashMap
跟创建动态数组 Vec 的方法类似,可以使用 new 方法来创建 HashMap,然后通过 insert 方法插入键值对。
使用 new 方法创建
#![allow(unused)] fn main() { use std::collections::HashMap; // 创建一个HashMap,用于存储宝石种类和对应的数量 let mut my_gems = HashMap::new(); // 将宝石类型和对应的数量写入表中 my_gems.insert("红宝石", 1); my_gems.insert("蓝宝石", 2); my_gems.insert("河边捡的误以为是宝石的破石头", 18); }
很简单对吧?跟其它语言没有区别,聪明的同学甚至能够猜到该 HashMap 的类型:HashMap<&str,i32>。
但是还有一点,你可能没有注意,那就是使用 HashMap 需要手动通过 use ... 从标准库中引入到我们当前的作用域中来,仔细回忆下,之前使用另外两个集合类型 String 和 Vec 时,我们是否有手动引用过?答案是 No,因为 HashMap 并没有包含在 Rust 的 prelude中(Rust 为了简化用户使用,提前将最常用的类型自动引入到作用域中)。
所有的集合类型都是动态的,意味着它们没有固定的内存大小,因此它们底层的数据都存储在内存堆上,然后通过一个存储在栈中的引用类型来访问。同时,跟其它集合类型一致,HashMap 也是内聚性的,即所有的 K 必须拥有同样的类型,V 也是如此。
跟
Vec一样,如果预先知道要存储的KV对个数,可以使用HashMap::with_capacity(capacity)创建指定大小的HashMap,避免频繁的内存分配和拷贝,提升性能
使用迭代器和 collect 方法创建
在实际使用中,不是所有的场景都能 new 一个哈希表后,然后悠哉悠哉的依次插入对应的键值对,而是可能会从另外一个数据结构中,获取到对应的数据,最终生成 HashMap。
例如考虑一个场景,有一张表格中记录了足球联赛中各队伍名称和积分的信息,这张表如果被导入到 Rust 项目中,一个合理的数据结构是 Vec<(String, u32)> 类型,该数组中的元素是一个个元组,该数据结构跟表格数据非常契合:表格中的数据都是逐行存储,每一个行都存有一个 (队伍名称, 积分) 的信息。
但是在很多时候,又需要通过队伍名称来查询对应的积分,此时动态数组就不适用了,因此可以用 HashMap 来保存相关的队伍名称 -> 积分映射关系。 理想很丰满,现实很骨感,如何将 Vec<(String, u32)> 中的数据快速写入到 HashMap<String, u32> 中?
一个动动脚趾头就能想到的笨方法如下:
fn main() { use std::collections::HashMap; let teams_list = vec![ ("中国队".to_string(), 100), ("美国队".to_string(), 10), ("日本队".to_string(), 50), ]; let mut teams_map = HashMap::new(); for team in &teams_list { teams_map.insert(&team.0, team.1); } println!("{:?}",teams_map) }
遍历列表,将每一个元组作为一对 KV 插入到 HashMap 中,很简单,但是……也不太聪明的样子,换个词说就是 —— 不够 rusty。
好在,Rust 为我们提供了一个非常精妙的解决办法:先将 Vec 转为迭代器,接着通过 collect 方法,将迭代器中的元素收集后,转成 HashMap:
fn main() { use std::collections::HashMap; let teams_list = vec![ ("中国队".to_string(), 100), ("美国队".to_string(), 10), ("日本队".to_string(), 50), ]; let teams_map: HashMap<_,_> = teams_list.into_iter().collect(); println!("{:?}",teams_map) }
代码很简单,into_iter 方法将列表转为迭代器,接着通过 collect 进行收集,不过需要注意的是,collect 方法在内部实际上支持生成多种类型的目标集合,因此我们需要通过类型标注 HashMap<_,_> 来告诉编译器:请帮我们收集为 HashMap 集合类型,具体的 KV 类型,麻烦编译器您老人家帮我们推导。
由此可见,Rust 中的编译器时而小聪明,时而大聪明,不过好在,它大聪明的时候,会自家人知道自己事,总归会通知你一声:
error[E0282]: type annotations needed // 需要类型标注
--> src/main.rs:10:9
|
10 | let teams_map = teams_list.into_iter().collect();
| ^^^^^^^^^ consider giving `teams_map` a type // 给予 `teams_map` 一个具体的类型
所有权转移
HashMap 的所有权规则与其它 Rust 类型没有区别:
- 若类型实现
Copy特征,该类型会被复制进HashMap,因此无所谓所有权 - 若没实现
Copy特征,所有权将被转移给HashMap中
例如我参选帅气男孩时的场景再现:
fn main() { use std::collections::HashMap; let name = String::from("Sunface"); let age = 18; let mut handsome_boys = HashMap::new(); handsome_boys.insert(name, age); println!("因为过于无耻,{}已经被从帅气男孩名单中除名", name); println!("还有,他的真实年龄远远不止{}岁", age); }
运行代码,报错如下:
error[E0382]: borrow of moved value: `name`
--> src/main.rs:10:32
|
4 | let name = String::from("Sunface");
| ---- move occurs because `name` has type `String`, which does not implement the `Copy` trait
...
8 | handsome_boys.insert(name, age);
| ---- value moved here
9 |
10 | println!("因为过于无耻,{}已经被除名", name);
| ^^^^ value borrowed here after move
提示很清晰,name 是 String 类型,因此它受到所有权的限制,在 insert 时,它的所有权被转移给 handsome_boys,所以最后在使用时,会遇到这个无情但是意料之中的报错。
如果你使用引用类型放入 HashMap 中,请确保该引用的生命周期至少跟 HashMap 活得一样久:
fn main() { use std::collections::HashMap; let name = String::from("Sunface"); let age = 18; let mut handsome_boys = HashMap::new(); handsome_boys.insert(&name, age); std::mem::drop(name); println!("因为过于无耻,{:?}已经被除名", handsome_boys); println!("还有,他的真实年龄远远不止{}岁", age); }
上面代码,我们借用 name 获取了它的引用,然后插入到 handsome_boys 中,至此一切都很完美。但是紧接着,就通过 drop 函数手动将 name 字符串从内存中移除,再然后就报错了:
handsome_boys.insert(&name, age);
| ----- borrow of `name` occurs here // name借用发生在此处
9 |
10 | std::mem::drop(name);
| ^^^^ move out of `name` occurs here // name的所有权被转移走
11 | println!("因为过于无耻,{:?}已经被除名", handsome_boys);
| ------------- borrow later used here // 所有权转移后,还试图使用name
最终,某人因为过于无耻,真正的被除名了 :)
查询 HashMap
通过 get 方法可以获取元素:
#![allow(unused)] fn main() { use std::collections::HashMap; let mut scores = HashMap::new(); scores.insert(String::from("Blue"), 10); scores.insert(String::from("Yellow"), 50); let team_name = String::from("Blue"); let score: Option<&i32> = scores.get(&team_name); }
上面有几点需要注意:
get方法返回一个Option<&i32>类型:当查询不到时,会返回一个None,查询到时返回Some(&i32)&i32是对HashMap中值的借用,如果不使用借用,可能会发生所有权的转移
还可以继续拓展下,上面的代码中,如果我们想直接获得值类型的 score 该怎么办,答案简约但不简单:
#![allow(unused)] fn main() { let score: i32 = scores.get(&team_name).copied().unwrap_or(0); }
这里留给大家一个小作业: 去官方文档中查询下 Option 的 copied 方法和 unwrap_or 方法的含义及该如何使用。
还可以通过循环的方式依次遍历 KV 对:
#![allow(unused)] fn main() { use std::collections::HashMap; let mut scores = HashMap::new(); scores.insert(String::from("Blue"), 10); scores.insert(String::from("Yellow"), 50); for (key, value) in &scores { println!("{}: {}", key, value); } }
最终输出:
Yellow: 50
Blue: 10
更新 HashMap 中的值
更新值的时候,涉及多种情况,咱们在代码中一一进行说明:
fn main() { use std::collections::HashMap; let mut scores = HashMap::new(); scores.insert("Blue", 10); // 覆盖已有的值 let old = scores.insert("Blue", 20); assert_eq!(old, Some(10)); // 查询新插入的值 let new = scores.get("Blue"); assert_eq!(new, Some(&20)); // 查询Yellow对应的值,若不存在则插入新值 let v = scores.entry("Yellow").or_insert(5); assert_eq!(*v, 5); // 不存在,插入5 // 查询Yellow对应的值,若不存在则插入新值 let v = scores.entry("Yellow").or_insert(50); assert_eq!(*v, 5); // 已经存在,因此50没有插入 }
具体的解释在代码注释中已有,这里不再进行赘述。
在已有值的基础上更新
另一个常用场景如下:查询某个 key 对应的值,若不存在则插入新值,若存在则对已有的值进行更新,例如在文本中统计词语出现的次数:
#![allow(unused)] fn main() { use std::collections::HashMap; let text = "hello world wonderful world"; let mut map = HashMap::new(); // 根据空格来切分字符串(英文单词都是通过空格切分) for word in text.split_whitespace() { let count = map.entry(word).or_insert(0); *count += 1; } println!("{:?}", map); }
上面代码中,新建一个 map 用于保存词语出现的次数,插入一个词语时会进行判断:若之前没有插入过,则使用该词语作 Key,插入次数 0 作为 Value,若之前插入过则取出之前统计的该词语出现的次数,对其加一。
有两点值得注意:
or_insert返回了&mut v引用,因此可以通过该可变引用直接修改map中对应的值- 使用
count引用时,需要先进行解引用*count,否则会出现类型不匹配
哈希函数
你肯定比较好奇,为何叫哈希表,到底什么是哈希。
先来设想下,如果要实现 Key 与 Value 的一一对应,是不是意味着我们要能比较两个 Key 的相等性?例如 "a" 和 "b",1 和 2,当这些类型做 Key 且能比较时,可以很容易知道 1 对应的值不会错误的映射到 2 上,因为 1 不等于 2。因此,一个类型能否作为 Key 的关键就是是否能进行相等比较,或者说该类型是否实现了 std::cmp::Eq 特征。
f32 和 f64 浮点数,没有实现
std::cmp::Eq特征,因此不可以用作HashMap的Key
好了,理解完这个,再来设想一点,若一个复杂点的类型作为 Key,那怎么在底层对它进行存储,怎么使用它进行查询和比较? 是不是很棘手?好在我们有哈希函数:通过它把 Key 计算后映射为哈希值,然后使用该哈希值来进行存储、查询、比较等操作。
但是问题又来了,如何保证不同 Key 通过哈希后的两个值不会相同?如果相同,那意味着我们使用不同的 Key,却查到了同一个结果,这种明显是错误的行为。
此时,就涉及到安全性跟性能的取舍了。
若要追求安全,尽可能减少冲突,同时防止拒绝服务(Denial of Service, DoS)攻击,就要使用密码学安全的哈希函数,HashMap 就是使用了这样的哈希函数。反之若要追求性能,就需要使用没有那么安全的算法。
高性能三方库
因此若性能测试显示当前标准库默认的哈希函数不能满足你的性能需求,就需要去 crates.io 上寻找其它的哈希函数实现,使用方法很简单:
#![allow(unused)] fn main() { use std::hash::BuildHasherDefault; use std::collections::HashMap; // 引入第三方的哈希函数 use twox_hash::XxHash64; // 指定HashMap使用第三方的哈希函数XxHash64 let mut hash: HashMap<_, _, BuildHasherDefault<XxHash64>> = Default::default(); hash.insert(42, "the answer"); assert_eq!(hash.get(&42), Some(&"the answer")); }
目前,
HashMap使用的哈希函数是SipHash,它的性能不是很高,但是安全性很高。SipHash在中等大小的Key上,性能相当不错,但是对于小型的Key(例如整数)或者大型Key(例如字符串)来说,性能还是不够好。若你需要极致性能,例如实现算法,可以考虑这个库:ahash
认识生命周期
生命周期,简而言之就是引用的有效作用域。在大多数时候,我们无需手动的声明生命周期,因为编译器可以自动进行推导,用类型来类比下:
- 就像编译器大部分时候可以自动推导类型 <-> 一样,编译器大多数时候也可以自动推导生命周期
- 在多种类型存在时,编译器往往要求我们手动标明类型 <-> 当多个生命周期存在,且编译器无法推导出某个引用的生命周期时,就需要我们手动标明生命周期
Rust 生命周期之所以难,是因为这个概念对于我们来说是全新的,没有其它编程语言的经验可以借鉴。当你觉得难的时候,不用过于担心,这个难对于所有人都是平等的,多点付出就能早点解决此拦路虎,同时本书也会尽力帮助大家减少学习难度(生命周期很可能是 Rust 中最难的部分)。
悬垂指针和生命周期
生命周期的主要作用是避免悬垂引用,它会导致程序引用了本不该引用的数据:
#![allow(unused)] fn main() { { let r; { let x = 5; r = &x; } println!("r: {}", r); } }
这段代码有几点值得注意:
let r;的声明方式貌似存在使用null的风险,实际上,当我们不初始化它就使用时,编译器会给予报错r引用了内部花括号中的x变量,但是x会在内部花括号}处被释放,因此回到外部花括号后,r会引用一个无效的x
此处 r 就是一个悬垂指针,它引用了提前被释放的变量 x,可以预料到,这段代码会报错:
error[E0597]: `x` does not live long enough // `x` 活得不够久
--> src/main.rs:7:17
|
7 | r = &x;
| ^^ borrowed value does not live long enough // 被借用的 `x` 活得不够久
8 | }
| - `x` dropped here while still borrowed // `x` 在这里被丢弃,但是它依然还在被借用
9 |
10 | println!("r: {}", r);
| - borrow later used here // 对 `x` 的借用在此处被使用
在这里 r 拥有更大的作用域,或者说活得更久。如果 Rust 不阻止该悬垂引用的发生,那么当 x 被释放后,r 所引用的值就不再是合法的,会导致我们程序发生异常行为,且该异常行为有时候会很难被发现。
借用检查
为了保证 Rust 的所有权和借用的正确性,Rust 使用了一个借用检查器(Borrow checker),来检查我们程序的借用正确性:
#![allow(unused)] fn main() { { let r; // ---------+-- 'a // | { // | let x = 5; // -+-- 'b | r = &x; // | | } // -+ | // | println!("r: {}", r); // | } // ---------+ }
这段代码和之前的一模一样,唯一的区别在于增加了对变量生命周期的注释。这里,r 变量被赋予了生命周期 'a,x 被赋予了生命周期 'b,从图示上可以明显看出生命周期 'b 比 'a 小很多。
在编译期,Rust 会比较两个变量的生命周期,结果发现 r 明明拥有生命周期 'a,但是却引用了一个小得多的生命周期 'b,在这种情况下,编译器会认为我们的程序存在风险,因此拒绝运行。
如果想要编译通过,也很简单,只要 'b 比 'a 大就好。总之,x 变量只要比 r 活得久,那么 r 就能随意引用 x 且不会存在危险:
#![allow(unused)] fn main() { { let x = 5; // ----------+-- 'b // | let r = &x; // --+-- 'a | // | | println!("r: {}", r); // | | // --+ | } // ----------+ }
根据之前的结论,我们重新实现了代码,现在 x 的生命周期 'b 大于 r 的生命周期 'a,因此 r 对 x 的引用是安全的。
通过之前的内容,我们了解了何为生命周期,也了解了 Rust 如何利用生命周期来确保引用是合法的,下面来看看函数中的生命周期。
函数中的生命周期
先来考虑一个例子 - 返回两个字符串切片中较长的那个,该函数的参数是两个字符串切片,返回值也是字符串切片:
fn main() { let string1 = String::from("abcd"); let string2 = "xyz"; let result = longest(string1.as_str(), string2); println!("The longest string is {}", result); }
#![allow(unused)] fn main() { fn longest(x: &str, y: &str) -> &str { if x.len() > y.len() { x } else { y } } }
这段 longest 实现,非常标准优美,就连多余的 return 和分号都没有,可是现实总是给我们重重一击:
error[E0106]: missing lifetime specifier
--> src/main.rs:9:33
|
9 | fn longest(x: &str, y: &str) -> &str {
| ---- ---- ^ expected named lifetime parameter // 参数需要一个生命周期
|
= help: this function's return type contains a borrowed value, but the signature does not say whether it is
borrowed from `x` or `y`
= 帮助: 该函数的返回值是一个引用类型,但是函数签名无法说明,该引用是借用自 `x` 还是 `y`
help: consider introducing a named lifetime parameter // 考虑引入一个生命周期
|
9 | fn longest<'a>(x: &'a str, y: &'a str) -> &'a str {
| ^^^^ ^^^^^^^ ^^^^^^^ ^^^
喔,这真是一个复杂的提示,那感觉就好像是生命周期去非诚勿扰相亲,结果在初印象环节就 23 盏灯全灭。等等,先别急,如果你仔细阅读,就会发现,其实主要是编译器无法知道该函数的返回值到底引用 x 还是 y ,因为编译器需要知道这些,来确保函数调用后的引用生命周期分析。
不过说来尴尬,就这个函数而言,我们也不知道返回值到底引用哪个,因为一个分支返回 x,另一个分支返回 y...这可咋办?先来分析下。
我们在定义该函数时,首先无法知道传递给函数的具体值,因此到底是 if 还是 else 被执行,无从得知。其次,传入引用的具体生命周期也无法知道,因此也不能像之前的例子那样通过分析生命周期来确定引用是否有效。同时,编译器的借用检查也无法推导出返回值的生命周期,因为它不知道 x 和 y 的生命周期跟返回值的生命周期之间的关系是怎样的(说实话,人都搞不清,何况编译器这个大聪明)。
因此,这时就回到了文章开头说的内容:在存在多个引用时,编译器有时会无法自动推导生命周期,此时就需要我们手动去标注,通过为参数标注合适的生命周期来帮助编译器进行借用检查的分析。
生命周期标注语法
生命周期标注并不会改变任何引用的实际作用域 -- 鲁迅
鲁迅说过的话,总是值得重点标注,当你未来更加理解生命周期时,你才会发现这句话的精髓和重要!现在先简单记住,标记的生命周期只是为了取悦编译器,让编译器不要难为我们,记住了吗?没记住,再回头看一遍,这对未来你遇到生命周期问题时会有很大的帮助!
在很多时候编译器是很聪明的,但是总有些时候,它会化身大聪明,自以为什么都很懂,然后去拒绝我们代码的执行,此时,就需要我们通过生命周期标注来告诉这个大聪明:别自作聪明了,听我的就好。
例如一个变量,只能活一个花括号,那么就算你给它标注一个活全局的生命周期,它还是会在前面的花括号结束处被释放掉,并不会真的全局存活。
生命周期的语法也颇为与众不同,以 ' 开头,名称往往是一个单独的小写字母,大多数人都用 'a 来作为生命周期的名称。 如果是引用类型的参数,那么生命周期会位于引用符号 & 之后,并用一个空格来将生命周期和引用参数分隔开:
#![allow(unused)] fn main() { &i32 // 一个引用 &'a i32 // 具有显式生命周期的引用 &'a mut i32 // 具有显式生命周期的可变引用 }
一个生命周期标注,它自身并不具有什么意义,因为生命周期的作用就是告诉编译器多个引用之间的关系。例如,有一个函数,它的第一个参数 first 是一个指向 i32 类型的引用,具有生命周期 'a,该函数还有另一个参数 second,它也是指向 i32 类型的引用,并且同样具有生命周期 'a。此处生命周期标注仅仅说明,这两个参数 first 和 second 至少活得和'a 一样久,至于到底活多久或者哪个活得更久,抱歉我们都无法得知:
#![allow(unused)] fn main() { fn useless<'a>(first: &'a i32, second: &'a i32) {} }
函数签名中的生命周期标注
继续之前的 longest 函数,从两个字符串切片中返回较长的那个:
#![allow(unused)] fn main() { fn longest<'a>(x: &'a str, y: &'a str) -> &'a str { if x.len() > y.len() { x } else { y } } }
需要注意的点如下:
- 和泛型一样,使用生命周期参数,需要先声明
<'a> x、y和返回值至少活得和'a一样久(因为返回值要么是x,要么是y)
该函数签名表明对于某些生命周期 'a,函数的两个参数都至少跟 'a 活得一样久,同时函数的返回引用也至少跟 'a 活得一样久。实际上,这意味着返回值的生命周期与参数生命周期中的较小值一致:虽然两个参数的生命周期都是标注了 'a,但是实际上这两个参数的真实生命周期可能是不一样的(生命周期 'a 不代表生命周期等于 'a,而是大于等于 'a)。
回忆下“鲁迅”说的话,再参考上面的内容,可以得出:在通过函数签名指定生命周期参数时,我们并没有改变传入引用或者返回引用的真实生命周期,而是告诉编译器当不满足此约束条件时,就拒绝编译通过。
因此 longest 函数并不知道 x 和 y 具体会活多久,只要知道它们的作用域至少能持续 'a 这么长就行。
当把具体的引用传给 longest 时,那生命周期 'a 的大小就是 x 和 y 的作用域的重合部分,换句话说,'a 的大小将等于 x 和 y 中较小的那个。由于返回值的生命周期也被标记为 'a,因此返回值的生命周期也是 x 和 y 中作用域较小的那个。
说实话,这段文字我写的都快崩溃了,不知道你们读起来如何,实在***太绕了。。那就干脆用一个例子来解释吧:
fn main() { let string1 = String::from("long string is long"); { let string2 = String::from("xyz"); let result = longest(string1.as_str(), string2.as_str()); println!("The longest string is {}", result); } }
在上例中,string1 的作用域直到 main 函数的结束,而 string2 的作用域到内部花括号的结束 },那么根据之前的理论,'a 是两者中作用域较小的那个,也就是 'a 的生命周期等于 string2 的生命周期,同理,由于函数返回的生命周期也是 'a,可以得出函数返回的生命周期也等于 string2 的生命周期。
现在来验证下上面的结论:result 的生命周期等于参数中生命周期最小的,因此要等于 string2 的生命周期,也就是说,result 要活得和 string2 一样久,观察下代码的实现,可以发现这个结论是正确的!
因此,在这种情况下,通过生命周期标注,编译器得出了和我们肉眼观察一样的结论,而不再是一个蒙圈的大聪明。
再来看一个例子,该例子证明了 result 的生命周期必须等于两个参数中生命周期较小的那个:
fn main() { let string1 = String::from("long string is long"); let result; { let string2 = String::from("xyz"); result = longest(string1.as_str(), string2.as_str()); } println!("The longest string is {}", result); }
Bang,错误冒头了:
error[E0597]: `string2` does not live long enough
--> src/main.rs:6:44
|
6 | result = longest(string1.as_str(), string2.as_str());
| ^^^^^^^ borrowed value does not live long enough
7 | }
| - `string2` dropped here while still borrowed
8 | println!("The longest string is {}", result);
| ------ borrow later used here
在上述代码中,result 必须要活到 println!处,因为 result 的生命周期是 'a,因此 'a 必须持续到 println!。
在 longest 函数中,string2 的生命周期也是 'a,由此说明 string2 也必须活到 println! 处,可是 string2 在代码中实际上只能活到内部语句块的花括号处 },小于它应该具备的生命周期 'a,因此编译出错。
作为人类,我们可以很清晰的看出 result 实际上引用了 string1,因为 string1 的长度明显要比 string2 长,既然如此,编译器不该如此矫情才对,它应该能认识到 result 没有引用 string2,让我们这段代码通过。只能说,作为尊贵的人类,编译器的发明者,你高估了这个工具的能力,它真的做不到!而且 Rust 编译器在调教上是非常保守的:当可能出错也可能不出错时,它会选择前者,抛出编译错误。
总之,显式的使用生命周期,可以让编译器正确的认识到多个引用之间的关系,最终帮我们提前规避可能存在的代码风险。
小练习:尝试着去更改 longest 函数,例如修改参数、生命周期或者返回值,然后推测结果如何,最后再跟编译器的输出进行印证。
深入思考生命周期标注
使用生命周期的方式往往取决于函数的功能,例如之前的 longest 函数,如果它永远只返回第一个参数 x,生命周期的标注该如何修改(该例子就是上面的小练习结果之一)?
#![allow(unused)] fn main() { fn longest<'a>(x: &'a str, y: &str) -> &'a str { x } }
在此例中,y 完全没有被使用,因此 y 的生命周期与 x 和返回值的生命周期没有任何关系,意味着我们也不必再为 y 标注生命周期,只需要标注 x 参数和返回值即可。
函数的返回值如果是一个引用类型,那么它的生命周期只会来源于:
- 函数参数的生命周期
- 函数体中某个新建引用的生命周期
若是后者情况,就是典型的悬垂引用场景:
#![allow(unused)] fn main() { fn longest<'a>(x: &str, y: &str) -> &'a str { let result = String::from("really long string"); result.as_str() } }
上面的函数的返回值就和参数 x,y 没有任何关系,而是引用了函数体内创建的字符串,那么很显然,该函数会报错:
error[E0515]: cannot return value referencing local variable `result` // 返回值result引用了本地的变量
--> src/main.rs:11:5
|
11 | result.as_str()
| ------^^^^^^^^^
| |
| returns a value referencing data owned by the current function
| `result` is borrowed here
主要问题就在于,result 在函数结束后就被释放,但是在函数结束后,对 result 的引用依然在继续。在这种情况下,没有办法指定合适的生命周期来让编译通过,因此我们也就在 Rust 中避免了悬垂引用。
那遇到这种情况该怎么办?最好的办法就是返回内部字符串的所有权,然后把字符串的所有权转移给调用者:
fn longest<'a>(_x: &str, _y: &str) -> String { String::from("really long string") } fn main() { let s = longest("not", "important"); }
至此,可以对生命周期进行下总结:生命周期语法用来将函数的多个引用参数和返回值的作用域关联到一起,一旦关联到一起后,Rust 就拥有充分的信息来确保我们的操作是内存安全的。
结构体中的生命周期
不仅仅函数具有生命周期,结构体其实也有这个概念,只不过我们之前对结构体的使用都停留在非引用类型字段上。细心的同学应该能回想起来,之前为什么不在结构体中使用字符串字面量或者字符串切片,而是统一使用 String 类型?原因很简单,后者在结构体初始化时,只要转移所有权即可,而前者,抱歉,它们是引用,它们不能为所欲为。
既然之前已经理解了生命周期,那么意味着在结构体中使用引用也变得可能:只要为结构体中的每一个引用标注上生命周期即可:
struct ImportantExcerpt<'a> { part: &'a str, } fn main() { let novel = String::from("Call me Ishmael. Some years ago..."); let first_sentence = novel.split('.').next().expect("Could not find a '.'"); let i = ImportantExcerpt { part: first_sentence, }; }
ImportantExcerpt 结构体中有一个引用类型的字段 part,因此需要为它标注上生命周期。结构体的生命周期标注语法跟泛型参数语法很像,需要对生命周期参数进行声明 <'a>。该生命周期标注说明,结构体 ImportantExcerpt 所引用的字符串 str 必须比该结构体活得更久。
从 main 函数实现来看,ImportantExcerpt 的生命周期从第 4 行开始,到 main 函数末尾结束,而该结构体引用的字符串从第一行开始,也是到 main 函数末尾结束,可以得出结论结构体引用的字符串活得比结构体久,这符合了编译器对生命周期的要求,因此编译通过。
与之相反,下面的代码就无法通过编译:
#[derive(Debug)] struct ImportantExcerpt<'a> { part: &'a str, } fn main() { let i; { let novel = String::from("Call me Ishmael. Some years ago..."); let first_sentence = novel.split('.').next().expect("Could not find a '.'"); i = ImportantExcerpt { part: first_sentence, }; } println!("{:?}",i); }
观察代码,可以看出结构体比它引用的字符串活得更久,引用字符串在内部语句块末尾 } 被释放后,println! 依然在外面使用了该结构体,因此会导致无效的引用,不出所料,编译报错:
error[E0597]: `novel` does not live long enough
--> src/main.rs:10:30
|
10 | let first_sentence = novel.split('.').next().expect("Could not find a '.'");
| ^^^^^^^^^^^^^^^^ borrowed value does not live long enough
...
14 | }
| - `novel` dropped here while still borrowed
15 | println!("{:?}",i);
| - borrow later used here
生命周期消除
实际上,对于编译器来说,每一个引用类型都有一个生命周期,那么为什么我们在使用过程中,很多时候无需标注生命周期?例如:
#![allow(unused)] fn main() { fn first_word(s: &str) -> &str { let bytes = s.as_bytes(); for (i, &item) in bytes.iter().enumerate() { if item == b' ' { return &s[0..i]; } } &s[..] } }
该函数的参数和返回值都是引用类型,尽管我们没有显式的为其标注生命周期,编译依然可以通过。其实原因不复杂,编译器为了简化用户的使用,运用了生命周期消除大法。
对于 first_word 函数,它的返回值是一个引用类型,那么该引用只有两种情况:
- 从参数获取
- 从函数体内部新创建的变量获取
如果是后者,就会出现悬垂引用,最终被编译器拒绝,因此只剩一种情况:返回值的引用是获取自参数,这就意味着参数和返回值的生命周期是一样的。道理很简单,我们能看出来,编译器自然也能看出来,因此,就算我们不标注生命周期,也不会产生歧义。
实际上,在 Rust 1.0 版本之前,这种代码果断不给通过,因为 Rust 要求必须显式的为所有引用标注生命周期:
#![allow(unused)] fn main() { fn first_word<'a>(s: &'a str) -> &'a str { }
在写了大量的类似代码后,Rust 社区抱怨声四起,包括开发者自己都忍不了了,最终揭锅而起,这才有了我们今日的幸福。
生命周期消除的规则不是一蹴而就,而是伴随着 总结-改善 流程的周而复始,一步一步走到今天,这也意味着,该规则以后可能也会进一步增加,我们需要手动标注生命周期的时候也会越来越少,hooray!
在开始之前有几点需要注意:
- 消除规则不是万能的,若编译器不能确定某件事是正确时,会直接判为不正确,那么你还是需要手动标注生命周期
- 函数或者方法中,参数的生命周期被称为
输入生命周期,返回值的生命周期被称为输出生命周期
三条消除规则
编译器使用三条消除规则来确定哪些场景不需要显式地去标注生命周期。其中第一条规则应用在输入生命周期上,第二、三条应用在输出生命周期上。若编译器发现三条规则都不适用时,就会报错,提示你需要手动标注生命周期。
-
每一个引用参数都会获得独自的生命周期
例如一个引用参数的函数就有一个生命周期标注:
fn foo<'a>(x: &'a i32),两个引用参数的有两个生命周期标注:fn foo<'a, 'b>(x: &'a i32, y: &'b i32), 依此类推。 -
若只有一个输入生命周期(函数参数中只有一个引用类型),那么该生命周期会被赋给所有的输出生命周期,也就是所有返回值的生命周期都等于该输入生命周期
例如函数
fn foo(x: &i32) -> &i32,x参数的生命周期会被自动赋给返回值&i32,因此该函数等同于fn foo<'a>(x: &'a i32) -> &'a i32 -
若存在多个输入生命周期,且其中一个是
&self或&mut self,则&self的生命周期被赋给所有的输出生命周期拥有
&self形式的参数,说明该函数是一个方法,该规则让方法的使用便利度大幅提升。
规则其实很好理解,但是,爱思考的读者肯定要发问了,例如第三条规则,若一个方法,它的返回值的生命周期就是跟参数 &self 的不一样怎么办?总不能强迫我返回的值总是和 &self 活得一样久吧?! 问得好,答案很简单:手动标注生命周期,因为这些规则只是编译器发现你没有标注生命周期时默认去使用的,当你标注生命周期后,编译器自然会乖乖听你的话。
让我们假装自己是编译器,然后看下以下的函数该如何应用这些规则:
例子 1
#![allow(unused)] fn main() { fn first_word(s: &str) -> &str { // 实际项目中的手写代码 }
首先,我们手写的代码如上所示时,编译器会先应用第一条规则,为每个参数标注一个生命周期:
#![allow(unused)] fn main() { fn first_word<'a>(s: &'a str) -> &str { // 编译器自动为参数添加生命周期 }
此时,第二条规则就可以进行应用,因为函数只有一个输入生命周期,因此该生命周期会被赋予所有的输出生命周期:
#![allow(unused)] fn main() { fn first_word<'a>(s: &'a str) -> &'a str { // 编译器自动为返回值添加生命周期 }
此时,编译器为函数签名中的所有引用都自动添加了具体的生命周期,因此编译通过,且用户无需手动去标注生命周期,只要按照 fn first_word(s: &str) -> &str { 的形式写代码即可。
例子 2 再来看一个例子:
#![allow(unused)] fn main() { fn longest(x: &str, y: &str) -> &str { // 实际项目中的手写代码 }
首先,编译器会应用第一条规则,为每个参数都标注生命周期:
#![allow(unused)] fn main() { fn longest<'a, 'b>(x: &'a str, y: &'b str) -> &str { }
但是此时,第二条规则却无法被使用,因为输入生命周期有两个,第三条规则也不符合,因为它是函数,不是方法,因此没有 &self 参数。在套用所有规则后,编译器依然无法为返回值标注合适的生命周期,因此,编译器就会报错,提示我们需要手动标注生命周期:
error[E0106]: missing lifetime specifier
--> src/main.rs:1:47
|
1 | fn longest<'a, 'b>(x: &'a str, y: &'b str) -> &str {
| ------- ------- ^ expected named lifetime parameter
|
= help: this function's return type contains a borrowed value, but the signature does not say whether it is borrowed from `x` or `y`
note: these named lifetimes are available to use
--> src/main.rs:1:12
|
1 | fn longest<'a, 'b>(x: &'a str, y: &'b str) -> &str {
| ^^ ^^
help: consider using one of the available lifetimes here
|
1 | fn longest<'a, 'b>(x: &'a str, y: &'b str) -> &'lifetime str {
| +++++++++
不得不说,Rust 编译器真的很强大,还贴心的给我们提示了该如何修改,虽然。。。好像。。。。它的提示貌似不太准确。这里我们更希望参数和返回值都是 'a 生命周期。
方法中的生命周期
先来回忆下泛型的语法:
#![allow(unused)] fn main() { struct Point<T> { x: T, y: T, } impl<T> Point<T> { fn x(&self) -> &T { &self.x } } }
实际上,为具有生命周期的结构体实现方法时,我们使用的语法跟泛型参数语法很相似:
#![allow(unused)] fn main() { struct ImportantExcerpt<'a> { part: &'a str, } impl<'a> ImportantExcerpt<'a> { fn level(&self) -> i32 { 3 } } }
其中有几点需要注意的:
impl中必须使用结构体的完整名称,包括<'a>,因为生命周期标注也是结构体类型的一部分!- 方法签名中,往往不需要标注生命周期,得益于生命周期消除的第一和第三规则
下面的例子展示了第三规则应用的场景:
#![allow(unused)] fn main() { impl<'a> ImportantExcerpt<'a> { fn announce_and_return_part(&self, announcement: &str) -> &str { println!("Attention please: {}", announcement); self.part } } }
首先,编译器应用第一规则,给予每个输入参数一个生命周期:
#![allow(unused)] fn main() { impl<'a> ImportantExcerpt<'a> { fn announce_and_return_part<'b>(&'a self, announcement: &'b str) -> &str { println!("Attention please: {}", announcement); self.part } } }
需要注意的是,编译器不知道 announcement 的生命周期到底多长,因此它无法简单的给予它生命周期 'a,而是重新声明了一个全新的生命周期 'b。
接着,编译器应用第三规则,将 &self 的生命周期赋给返回值 &str:
#![allow(unused)] fn main() { impl<'a> ImportantExcerpt<'a> { fn announce_and_return_part<'b>(&'a self, announcement: &'b str) -> &'a str { println!("Attention please: {}", announcement); self.part } } }
Bingo,最开始的代码,尽管我们没有给方法标注生命周期,但是在第一和第三规则的配合下,编译器依然完美的为我们亮起了绿灯。
在结束这块儿内容之前,再来做一个有趣的修改,将方法返回的生命周期改为'b:
#![allow(unused)] fn main() { impl<'a> ImportantExcerpt<'a> { fn announce_and_return_part<'b>(&'a self, announcement: &'b str) -> &'b str { println!("Attention please: {}", announcement); self.part } } }
此时,编译器会报错,因为编译器无法知道 'a 和 'b 的关系。 &self 生命周期是 'a,那么 self.part 的生命周期也是 'a,但是好巧不巧的是,我们手动为返回值 self.part 标注了生命周期 'b,因此编译器需要知道 'a 和 'b 的关系。
有一点很容易推理出来:由于 &'a self 是被引用的一方,因此引用它的 &'b str 必须要活得比它短,否则会出现悬垂引用。因此说明生命周期 'b 必须要比 'a 小,只要满足了这一点,编译器就不会再报错:
#![allow(unused)] fn main() { impl<'a: 'b, 'b> ImportantExcerpt<'a> { fn announce_and_return_part(&'a self, announcement: &'b str) -> &'b str { println!("Attention please: {}", announcement); self.part } } }
Bang,一个复杂的玩意儿被甩到了你面前,就问怕不怕?
就关键点稍微解释下:
'a: 'b,是生命周期约束语法,跟泛型约束非常相似,用于说明'a必须比'b活得久- 可以把
'a和'b都在同一个地方声明(如上),或者分开声明但通过where 'a: 'b约束生命周期关系,如下:
#![allow(unused)] fn main() { impl<'a> ImportantExcerpt<'a> { fn announce_and_return_part<'b>(&'a self, announcement: &'b str) -> &'b str where 'a: 'b, { println!("Attention please: {}", announcement); self.part } } }
总之,实现方法比想象中简单:加一个约束,就能暗示编译器,尽管引用吧,反正我想引用的内容比我活得久,爱咋咋地,我怎么都不会引用到无效的内容!
静态生命周期
在 Rust 中有一个非常特殊的生命周期,那就是 'static,拥有该生命周期的引用可以和整个程序活得一样久。
在之前我们学过字符串字面量,提到过它是被硬编码进 Rust 的二进制文件中,因此这些字符串变量全部具有 'static 的生命周期:
#![allow(unused)] fn main() { let s: &'static str = "我没啥优点,就是活得久,嘿嘿"; }
这时候,有些聪明的小脑瓜就开始开动了:当生命周期不知道怎么标时,对类型施加一个静态生命周期的约束 T: 'static 是不是很爽?这样我和编译器再也不用操心它到底活多久了。
嗯,只能说,这个想法是对的,在不少情况下,'static 约束确实可以解决生命周期编译不通过的问题,但是问题来了:本来该引用没有活那么久,但是你非要说它活那么久,万一引入了潜在的 BUG 怎么办?
因此,遇到因为生命周期导致的编译不通过问题,首先想的应该是:是否是我们试图创建一个悬垂引用,或者是试图匹配不一致的生命周期,而不是简单粗暴的用 'static 来解决问题。
但是,话说回来,存在即合理,有时候,'static 确实可以帮助我们解决非常复杂的生命周期问题甚至是无法被手动解决的生命周期问题,那么此时就应该放心大胆的用,只要你确定:你的所有引用的生命周期都是正确的,只是编译器太笨不懂罢了。
总结下:
- 生命周期
'static意味着能和程序活得一样久,例如字符串字面量和特征对象 - 实在遇到解决不了的生命周期标注问题,可以尝试
T: 'static,有时候它会给你奇迹
事实上,关于
'static, 有两种用法:&'static和T: 'static
一个复杂例子: 泛型、特征约束
手指已经疲软无力,我好想停止,但是华丽的开场都要有与之匹配的谢幕,那我们就用一个稍微复杂点的例子来结束:
#![allow(unused)] fn main() { use std::fmt::Display; fn longest_with_an_announcement<'a, T>( x: &'a str, y: &'a str, ann: T, ) -> &'a str where T: Display, { println!("Announcement! {}", ann); if x.len() > y.len() { x } else { y } } }
依然是熟悉的配方 longest,但是多了一段废话: ann,因为要用格式化 {} 来输出 ann,因此需要它实现 Display 特征。
返回值和错误处理
Rust 中的错误主要分为两类:
- 可恢复错误,通常用于从系统全局角度来看可以接受的错误,例如处理用户的访问、操作等错误,这些错误只会影响某个用户自身的操作进程,而不会对系统的全局稳定性产生影响
- 不可恢复错误,刚好相反,该错误通常是全局性或者系统性的错误,例如数组越界访问,系统启动时发生了影响启动流程的错误等等,这些错误的影响往往对于系统来说是致命的
很多编程语言,并不会区分这些错误,而是直接采用异常的方式去处理。Rust 没有异常,但是 Rust 也有自己的卧龙凤雏:Result<T, E> 用于可恢复错误,panic! 用于不可恢复错误
panic 深入剖析
在正式开始之前,先来思考一个问题:假设我们想要从文件读取数据,如果失败,你有没有好的办法通知调用者为何失败?如果成功,你有没有好的办法把读取的结果返还给调用者?
panic! 与不可恢复错误
上面的问题在真实场景会经常遇到,其实处理起来挺复杂的,让我们先做一个假设:文件读取操作发生在系统启动阶段。那么可以轻易得出一个结论,一旦文件读取失败,那么系统启动也将失败,这意味着该失败是不可恢复的错误,无论是因为文件不存在还是操作系统硬盘的问题,这些只是错误的原因不同,但是归根到底都是不可恢复的错误(梳理清楚当前场景的错误类型非常重要)。
对于这些严重到影响程序运行的错误,触发 panic 是很好的解决方式。在 Rust 中触发 panic 有两种方式:被动触发和主动调用,下面依次来看看。
被动触发
先来看一段简单又熟悉的代码:
fn main() { let v = vec![1, 2, 3]; v[99]; }
心明眼亮的同学立马就能看出这里发生了严重的错误 —— 数组访问越界,在其它编程语言中无一例外,都会报出严重的异常,甚至导致程序直接崩溃关闭。
而 Rust 也不例外,运行后将看到如下报错:
$ cargo run
Compiling panic v0.1.0 (file:///projects/panic)
Finished dev [unoptimized + debuginfo] target(s) in 0.27s
Running `target/debug/panic`
thread 'main' panicked at 'index out of bounds: the len is 3 but the index is 99', src/main.rs:4:5
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
上面给出了非常详细的报错信息,包含了具体的异常描述以及发生的位置,甚至你还可以加入额外的命令来看到异常发生时的堆栈信息,这个会在后面详细展开。
总之,类似的 panic 还有很多,而被动触发的 panic 是我们日常开发中最常遇到的,这也是 Rust 给我们的一种保护,毕竟错误只有抛出来,才有可能被处理,否则只会偷偷隐藏起来,寻觅时机给你致命一击。
主动调用
在某些特殊场景中,开发者想要主动抛出一个异常,例如开头提到的在系统启动阶段读取文件失败。
对此,Rust 为我们提供了 panic! 宏,当调用执行该宏时,程序会打印出一个错误信息,展开报错点往前的函数调用堆栈,最后退出程序。
切记,一定是不可恢复的错误,才调用
panic!处理,你总不想系统仅仅因为用户随便传入一个非法参数就崩溃吧?所以,只有当你不知道该如何处理时,再去调用 panic!.
首先,来调用一下 panic!,这里使用了最简单的代码实现,实际上你在程序的任何地方都可以这样调用:
fn main() { panic!("crash and burn"); }
运行后输出:
thread 'main' panicked at 'crash and burn', src/main.rs:2:5
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
以上信息包含了两条重要信息:
main函数所在的线程崩溃了,发生的代码位置是src/main.rs中的第 2 行第 5 个字符(去除该行前面的空字符)- 在使用时加上一个环境变量可以获取更详细的栈展开信息:
- Linux/macOS 等 UNIX 系统:
RUST_BACKTRACE=1 cargo run - Windows 系统(PowerShell):
$env:RUST_BACKTRACE=1 ; cargo run
- Linux/macOS 等 UNIX 系统:
下面让我们针对第二点进行详细展开讲解。
backtrace 栈展开
在真实场景中,错误往往涉及到很长的调用链甚至会深入第三方库,如果没有栈展开技术,错误将难以跟踪处理,下面我们来看一个真实的崩溃例子:
fn main() { let v = vec![1, 2, 3]; v[99]; }
上面的代码很简单,数组只有 3 个元素,我们却尝试去访问它的第 100 号元素(数组索引从 0 开始),那自然会崩溃。
我们的读者里不乏正义之士,此时肯定要质疑,一个简单的数组越界访问,为何要直接让程序崩溃?是不是有些小题大作了?
如果有过 C 语言的经验,即使你越界了,问题不大,我依然尝试去访问,至于这个值是不是你想要的(100 号内存地址也有可能有值,只不过是其它变量或者程序的!),抱歉,不归我管,我只负责取,你要负责管理好自己的索引访问范围。上面这种情况被称为缓冲区溢出,并可能会导致安全漏洞,例如攻击者可以通过索引来访问到数组后面不被允许的数据。
说实话,我宁愿程序崩溃,为什么?当你取到了一个不属于你的值,这在很多时候会导致程序上的逻辑 BUG! 有编程经验的人都知道这种逻辑上的 BUG 是多么难被发现和修复!因此程序直接崩溃,然后告诉我们问题发生的位置,最后我们对此进行修复,这才是最合理的软件开发流程,而不是把问题藏着掖着:
thread 'main' panicked at 'index out of bounds: the len is 3 but the index is 99', src/main.rs:4:5
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
好的,现在成功知道问题发生的位置,但是如果我们想知道该问题之前经过了哪些调用环节,该怎么办?那就按照提示使用 RUST_BACKTRACE=1 cargo run 或 $env:RUST_BACKTRACE=1 ; cargo run 来再一次运行程序:
thread 'main' panicked at 'index out of bounds: the len is 3 but the index is 99', src/main.rs:4:5
stack backtrace:
0: rust_begin_unwind
at /rustc/59eed8a2aac0230a8b53e89d4e99d55912ba6b35/library/std/src/panicking.rs:517:5
1: core::panicking::panic_fmt
at /rustc/59eed8a2aac0230a8b53e89d4e99d55912ba6b35/library/core/src/panicking.rs:101:14
2: core::panicking::panic_bounds_check
at /rustc/59eed8a2aac0230a8b53e89d4e99d55912ba6b35/library/core/src/panicking.rs:77:5
3: <usize as core::slice::index::SliceIndex<[T]>>::index
at /rustc/59eed8a2aac0230a8b53e89d4e99d55912ba6b35/library/core/src/slice/index.rs:184:10
4: core::slice::index::<impl core::ops::index::Index<I> for [T]>::index
at /rustc/59eed8a2aac0230a8b53e89d4e99d55912ba6b35/library/core/src/slice/index.rs:15:9
5: <alloc::vec::Vec<T,A> as core::ops::index::Index<I>>::index
at /rustc/59eed8a2aac0230a8b53e89d4e99d55912ba6b35/library/alloc/src/vec/mod.rs:2465:9
6: world_hello::main
at ./src/main.rs:4:5
7: core::ops::function::FnOnce::call_once
at /rustc/59eed8a2aac0230a8b53e89d4e99d55912ba6b35/library/core/src/ops/function.rs:227:5
note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.
上面的代码就是一次栈展开(也称栈回溯),它包含了函数调用的顺序,当然按照逆序排列:最近调用的函数排在列表的最上方。因为咱们的 main 函数基本是最先调用的函数了,所以排在了倒数第二位,还有一个关注点,排在最顶部最后一个调用的函数是 rust_begin_unwind,该函数的目的就是进行栈展开,呈现这些列表信息给我们。
要获取到栈回溯信息,你还需要开启 debug 标志,该标志在使用 cargo run 或者 cargo build 时自动开启(这两个操作默认是 Debug 运行方式)。同时,栈展开信息在不同操作系统或者 Rust 版本上也有所不同。
panic 时的两种终止方式
当出现 panic! 时,程序提供了两种方式来处理终止流程:栈展开和直接终止。
其中,默认的方式就是 栈展开,这意味着 Rust 会回溯栈上数据和函数调用,因此也意味着更多的善后工作,好处是可以给出充分的报错信息和栈调用信息,便于事后的问题复盘。直接终止,顾名思义,不清理数据就直接退出程序,善后工作交与操作系统来负责。
对于绝大多数用户,使用默认选择是最好的,但是当你关心最终编译出的二进制可执行文件大小时,那么可以尝试去使用直接终止的方式,例如下面的配置修改 Cargo.toml 文件,实现在release模式下遇到 panic 直接终止:
#![allow(unused)] fn main() { [profile.release] panic = 'abort' }
线程 panic 后,程序是否会终止?
长话短说,如果是 main 线程,则程序会终止,如果是其它子线程,该线程会终止,但是不会影响 main 线程。因此,尽量不要在 main 线程中做太多任务,将这些任务交由子线程去做,就算子线程 panic 也不会导致整个程序的结束。
何时该使用 panic!
下面让我们大概罗列下何时适合使用 panic,也许经过之前的学习,你已经能够对 panic 的使用有了自己的看法,但是我们还是会罗列一些常见的用法来加深你的理解。
先来一点背景知识,在前面章节我们粗略讲过 Result<T, E> 这个枚举类型,它是用来表示函数的返回结果:
#![allow(unused)] fn main() { enum Result<T, E> { Ok(T), Err(E), } }
当没有错误发生时,函数返回一个用 Result 类型包裹的值 Ok(T),当错误时,返回一个 Err(E)。对于 Result 返回我们有很多处理方法,最简单粗暴的就是 unwrap 和 expect,这两个函数非常类似,我们以 unwrap 举例:
#![allow(unused)] fn main() { use std::net::IpAddr; let home: IpAddr = "127.0.0.1".parse().unwrap(); }
上面的 parse 方法试图将字符串 "127.0.0.1" 解析为一个 IP 地址类型 IpAddr,它返回一个 Result<IpAddr, E> 类型,如果解析成功,则把 Ok(IpAddr) 中的值赋给 home,如果失败,则不处理 Err(E),而是直接 panic。
因此 unwrap 简而言之:成功则返回值,失败则 panic,总之不进行任何错误处理。
示例、原型、测试
这几个场景下,需要快速地搭建代码,错误处理会拖慢编码的速度,也不是特别有必要,因此通过 unwrap、expect 等方法来处理是最快的。
同时,当我们回头准备做错误处理时,可以全局搜索这些方法,不遗漏地进行替换。
你确切的知道你的程序是正确时,可以使用 panic
因为 panic 的触发方式比错误处理要简单,因此可以让代码更清晰,可读性也更加好,当我们的代码注定是正确时,你可以用 unwrap 等方法直接进行处理,反正也不可能 panic :
#![allow(unused)] fn main() { use std::net::IpAddr; let home: IpAddr = "127.0.0.1".parse().unwrap(); }
例如上面的例子,"127.0.0.1" 就是 ip 地址,因此我们知道 parse 方法一定会成功,那么就可以直接用 unwrap 方法进行处理。
当然,如果该字符串是来自于用户输入,那在实际项目中,就必须用错误处理的方式,而不是 unwrap,否则你的程序一天要崩溃几十万次吧!
可能导致全局有害状态时
有害状态大概分为几类:
- 非预期的错误
- 后续代码的运行会受到显著影响
- 内存安全的问题
当错误预期会出现时,返回一个错误较为合适,例如解析器接收到格式错误的数据,HTTP 请求接收到错误的参数甚至该请求内的任何错误(不会导致整个程序有问题,只影响该次请求)。因为错误是可预期的,因此也是可以处理的。
当启动时某个流程发生了错误,对后续代码的运行造成了影响,那么就应该使用 panic,而不是处理错误后继续运行,当然你可以通过重试的方式来继续。
上面提到过,数组访问越界,就要 panic 的原因,这个就是属于内存安全的范畴,一旦内存访问不安全,那么我们就无法保证自己的程序会怎么运行下去,也无法保证逻辑和数据的正确性。
panic 原理剖析
本来不想写这块儿内容,因为真的难写,但是转念一想,既然号称圣经,那么本书就得与众不同,避重就轻显然不是该有的态度。
当调用 panic! 宏时,它会
- 格式化
panic信息,然后使用该信息作为参数,调用std::panic::panic_any()函数 panic_any会检查应用是否使用了panic hook,如果使用了,该hook函数就会被调用(hook是一个钩子函数,是外部代码设置的,用于在panic触发时,执行外部代码所需的功能)- 当
hook函数返回后,当前的线程就开始进行栈展开:从panic_any开始,如果寄存器或者栈因为某些原因信息错乱了,那很可能该展开会发生异常,最终线程会直接停止,展开也无法继续进行 - 展开的过程是一帧一帧的去回溯整个栈,每个帧的数据都会随之被丢弃,但是在展开过程中,你可能会遇到被用户标记为
catching的帧(通过std::panic::catch_unwind()函数标记),此时用户提供的catch函数会被调用,展开也随之停止:当然,如果catch选择在内部调用std::panic::resume_unwind()函数,则展开还会继续。
还有一种情况,在展开过程中,如果展开本身 panic 了,那展开线程会终止,展开也随之停止。
一旦线程展开被终止或者完成,最终的输出结果是取决于哪个线程 panic:对于 main 线程,操作系统提供的终止功能 core::intrinsics::abort() 会被调用,最终结束当前的 panic 进程;如果是其它子线程,那么子线程就会简单的终止,同时信息会在稍后通过 std::thread::join() 进行收集。
可恢复的错误 Result
还记得上一节中,提到的关于文件读取的思考题吧?当时我们解决了读取文件时遇到不可恢复错误该怎么处理的问题,现在来看看,读取过程中,正常返回和遇到可以恢复的错误时该如何处理。
假设,我们有一台消息服务器,每个用户都通过 websocket 连接到该服务器来接收和发送消息,该过程就涉及到 socket 文件的读写,那么此时,如果一个用户的读写发生了错误,显然不能直接 panic,否则服务器会直接崩溃,所有用户都会断开连接,因此我们需要一种更温和的错误处理方式:Result<T, E>。
之前章节有提到过,Result<T, E> 是一个枚举类型,定义如下:
#![allow(unused)] fn main() { enum Result<T, E> { Ok(T), Err(E), } }
泛型参数 T 代表成功时存入的正确值的类型,存放方式是 Ok(T),E 代表错误时存入的错误值,存放方式是 Err(E),枯燥的讲解永远不及代码生动准确,因此先来看下打开文件的例子:
use std::fs::File; fn main() { let f = File::open("hello.txt"); }
以上 File::open 返回一个 Result 类型,那么问题来了:
如何获知变量类型或者函数的返回类型
有几种常用的方式,此处更推荐第二种方法:
- 第一种是查询标准库或者三方库文档,搜索
File,然后找到它的open方法- 推荐
VSCodeIDE 和rust-analyzer插件,如果你成功安装的话,那么就可以在VSCode中很方便的通过代码跳转的方式查看代码,同时rust-analyzer插件还会对代码中的类型进行标注,非常方便好用!- 你还可以尝试故意标记一个错误的类型,然后让编译器告诉你:
#![allow(unused)] fn main() { let f: u32 = File::open("hello.txt"); }
错误提示如下:
error[E0308]: mismatched types
--> src/main.rs:4:18
|
4 | let f: u32 = File::open("hello.txt");
| ^^^^^^^^^^^^^^^^^^^^^^^ expected u32, found enum
`std::result::Result`
|
= note: expected type `u32`
found type `std::result::Result<std::fs::File, std::io::Error>`
上面代码,故意将 f 类型标记成整形,编译器立刻不乐意了,你是在忽悠我吗?打开文件操作返回一个整形?来,大哥来告诉你返回什么:std::result::Result<std::fs::File, std::io::Error>,我的天呐,怎么这么长的类型!
别慌,其实很简单,首先 Result 本身是定义在 std::result 中的,但是因为 Result 很常用,所以就被包含在了prelude中(将常用的东东提前引入到当前作用域内),因此无需手动引入 std::result::Result,那么返回类型可以简化为 Result<std::fs::File,std::io::Error>,你看看是不是很像标准的 Result<T, E> 枚举定义?只不过 T 被替换成了具体的类型 std::fs::File,是一个文件句柄类型,E 被替换成 std::io::Error,是一个 IO 错误类型.
这个返回值类型说明 File::open 调用如果成功则返回一个可以进行读写的文件句柄,如果失败,则返回一个 IO 错误:文件不存在或者没有访问文件的权限等。总之 File::open 需要一个方式告知调用者是成功还是失败,并同时返回具体的文件句柄(成功)或错误信息(失败),万幸的是,这些信息可以通过 Result 枚举提供:
use std::fs::File; fn main() { let f = File::open("hello.txt"); let f = match f { Ok(file) => file, Err(error) => { panic!("Problem opening the file: {:?}", error) }, }; }
代码很清晰,对打开文件后的 Result<T, E> 类型进行匹配取值,如果是成功,则将 Ok(file) 中存放的的文件句柄 file 赋值给 f,如果失败,则将 Err(error) 中存放的错误信息 error 使用 panic 抛出来,进而结束程序,这非常符合上文提到过的 panic 使用场景。
好吧,也没有那么合理 :)
对返回的错误进行处理
直接 panic 还是过于粗暴,因为实际上 IO 的错误有很多种,我们需要对部分错误进行特殊处理,而不是所有错误都直接崩溃:
use std::fs::File; use std::io::ErrorKind; fn main() { let f = File::open("hello.txt"); let f = match f { Ok(file) => file, Err(error) => match error.kind() { ErrorKind::NotFound => match File::create("hello.txt") { Ok(fc) => fc, Err(e) => panic!("Problem creating the file: {:?}", e), }, other_error => panic!("Problem opening the file: {:?}", other_error), }, }; }
上面代码在匹配出 error 后,又对 error 进行了详细的匹配解析,最终结果:
- 如果是文件不存在错误
ErrorKind::NotFound,就创建文件,这里创建文件File::create也是返回Result,因此继续用match对其结果进行处理:创建成功,将新的文件句柄赋值给f,如果失败,则panic - 剩下的错误,一律
panic
失败就 panic: unwrap 和 expect
上一节中,已经看到过这两兄弟的简单介绍,这里再来回顾下。
在不需要处理错误的场景,例如写原型、示例时,我们不想使用 match 去匹配 Result<T, E> 以获取其中的 T 值,因为 match 的穷尽匹配特性,你总要去处理下 Err 分支。那么有没有办法简化这个过程?有,答案就是 unwrap 和 expect。
它们的作用就是,如果返回成功,就将 Ok(T) 中的值取出来,如果失败,就直接 panic,真的勇士绝不多 BB,直接崩溃。
use std::fs::File; fn main() { let f = File::open("hello.txt").unwrap(); }
如果调用这段代码时 hello.txt 文件不存在,那么 unwrap 就将直接 panic:
thread 'main' panicked at 'called `Result::unwrap()` on an `Err` value: Os { code: 2, kind: NotFound, message: "No such file or directory" }', src/main.rs:4:37
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
expect 跟 unwrap 很像,也是遇到错误直接 panic, 但是会带上自定义的错误提示信息,相当于重载了错误打印的函数:
use std::fs::File; fn main() { let f = File::open("hello.txt").expect("Failed to open hello.txt"); }
报错如下:
thread 'main' panicked at 'Failed to open hello.txt: Os { code: 2, kind: NotFound, message: "No such file or directory" }', src/main.rs:4:37
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
可以看出,expect 相比 unwrap 能提供更精确的错误信息,在有些场景也会更加实用。
传播错误
咱们的程序几乎不太可能只有 A->B 形式的函数调用,一个设计良好的程序,一个功能涉及十几层的函数调用都有可能。而错误处理也往往不是哪里调用出错,就在哪里处理,实际应用中,大概率会把错误层层上传然后交给调用链的上游函数进行处理,错误传播将极为常见。
例如以下函数从文件中读取用户名,然后将结果进行返回:
#![allow(unused)] fn main() { use std::fs::File; use std::io::{self, Read}; fn read_username_from_file() -> Result<String, io::Error> { // 打开文件,f是`Result<文件句柄,io::Error>` let f = File::open("hello.txt"); let mut f = match f { // 打开文件成功,将file句柄赋值给f Ok(file) => file, // 打开文件失败,将错误返回(向上传播) Err(e) => return Err(e), }; // 创建动态字符串s let mut s = String::new(); // 从f文件句柄读取数据并写入s中 match f.read_to_string(&mut s) { // 读取成功,返回Ok封装的字符串 Ok(_) => Ok(s), // 将错误向上传播 Err(e) => Err(e), } } }
有几点值得注意:
- 该函数返回一个
Result<String, io::Error>类型,当读取用户名成功时,返回Ok(String),失败时,返回Err(io:Error) File::open和f.read_to_string返回的Result<T, E>中的E就是io::Error
由此可见,该函数将 io::Error 的错误往上进行传播,该函数的调用者最终会对 Result<String,io::Error> 进行再处理,至于怎么处理就是调用者的事,如果是错误,它可以选择继续向上传播错误,也可以直接 panic,亦或将具体的错误原因包装后写入 socket 中呈现给终端用户。
但是上面的代码也有自己的问题,那就是太长了(优秀的程序员身上的优点极多,其中最大的优点就是懒),我自认为也有那么一点点优秀,因此见不得这么啰嗦的代码,下面咱们来讲讲如何简化它。
传播界的大明星: ?
大明星出场,必须得有排面,来看看 ? 的排面:
#![allow(unused)] fn main() { use std::fs::File; use std::io; use std::io::Read; fn read_username_from_file() -> Result<String, io::Error> { let mut f = File::open("hello.txt")?; let mut s = String::new(); f.read_to_string(&mut s)?; Ok(s) } }
看到没,这就是排面,相比前面的 match 处理错误的函数,代码直接减少了一半不止,但是,一山更比一山难,看不懂啊!
其实 ? 就是一个宏,它的作用跟上面的 match 几乎一模一样:
#![allow(unused)] fn main() { let mut f = match f { // 打开文件成功,将file句柄赋值给f Ok(file) => file, // 打开文件失败,将错误返回(向上传播) Err(e) => return Err(e), }; }
如果结果是 Ok(T),则把 T 赋值给 f,如果结果是 Err(E),则返回该错误,所以 ? 特别适合用来传播错误。
虽然 ? 和 match 功能一致,但是事实上 ? 会更胜一筹。何解?
想象一下,一个设计良好的系统中,肯定有自定义的错误特征,错误之间很可能会存在上下级关系,例如标准库中的 std::io::Error 和 std::error::Error,前者是 IO 相关的错误结构体,后者是一个最最通用的标准错误特征,同时前者实现了后者,因此 std::io::Error 可以转换为 std:error::Error。
明白了以上的错误转换,? 的更胜一筹就很好理解了,它可以自动进行类型提升(转换):
#![allow(unused)] fn main() { fn open_file() -> Result<File, Box<dyn std::error::Error>> { let mut f = File::open("hello.txt")?; Ok(f) } }
上面代码中 File::open 报错时返回的错误是 std::io::Error 类型,但是 open_file 函数返回的错误类型是 std::error::Error 的特征对象,可以看到一个错误类型通过 ? 返回后,变成了另一个错误类型,这就是 ? 的神奇之处。
根本原因是在于标准库中定义的 From 特征,该特征有一个方法 from,用于把一个类型转成另外一个类型,? 可以自动调用该方法,然后进行隐式类型转换。因此只要函数返回的错误 ReturnError 实现了 From<OtherError> 特征,那么 ? 就会自动把 OtherError 转换为 ReturnError。
这种转换非常好用,意味着你可以用一个大而全的 ReturnError 来覆盖所有错误类型,只需要为各种子错误类型实现这种转换即可。
强中自有强中手,一码更比一码短:
#![allow(unused)] fn main() { use std::fs::File; use std::io; use std::io::Read; fn read_username_from_file() -> Result<String, io::Error> { let mut s = String::new(); File::open("hello.txt")?.read_to_string(&mut s)?; Ok(s) } }
瞧见没? ? 还能实现链式调用,File::open 遇到错误就返回,没有错误就将 Ok 中的值取出来用于下一个方法调用,简直太精妙了,从 Go 语言过来的我,内心狂喜(其实学 Rust 的苦和痛我才不会告诉你们)。
不仅有更强,还要有最强,我不信还有人比我更短(不要误解):
#![allow(unused)] fn main() { use std::fs; use std::io; fn read_username_from_file() -> Result<String, io::Error> { // read_to_string是定义在std::io中的方法,因此需要在上面进行引用 fs::read_to_string("hello.txt") } }
从文件读取数据到字符串中,是比较常见的操作,因此 Rust 标准库为我们提供了 fs::read_to_string 函数,该函数内部会打开一个文件、创建 String、读取文件内容最后写入字符串并返回,因为该函数其实与本章讲的内容关系不大,因此放在最后来讲,其实只是我想震你们一下 :)
? 用于 Option 的返回
? 不仅仅可以用于 Result 的传播,还能用于 Option 的传播,再来回忆下 Option 的定义:
#![allow(unused)] fn main() { pub enum Option<T> { Some(T), None } }
Result 通过 ? 返回错误,那么 Option 就通过 ? 返回 None:
#![allow(unused)] fn main() { fn first(arr: &[i32]) -> Option<&i32> { let v = arr.get(0)?; Some(v) } }
上面的函数中,arr.get 返回一个 Option<&i32> 类型,因为 ? 的使用,如果 get 的结果是 None,则直接返回 None,如果是 Some(&i32),则把里面的值赋给 v。
其实这个函数有些画蛇添足,我们完全可以写出更简单的版本:
#![allow(unused)] fn main() { fn first(arr: &[i32]) -> Option<&i32> { arr.get(0) } }
有一句话怎么说?没有需求,制造需求也要上……大家别跟我学习,这是软件开发大忌。只能用代码洗洗眼了:
#![allow(unused)] fn main() { fn last_char_of_first_line(text: &str) -> Option<char> { text.lines().next()?.chars().last() } }
上面代码展示了在链式调用中使用 ? 提前返回 None 的用法, .next 方法返回的是 Option 类型:如果返回 Some(&str),那么继续调用 chars 方法,如果返回 None,则直接从整个函数中返回 None,不再继续进行链式调用。
新手用 ? 常会犯的错误
初学者在用 ? 时,老是会犯错,例如写出这样的代码:
#![allow(unused)] fn main() { fn first(arr: &[i32]) -> Option<&i32> { arr.get(0)? } }
这段代码无法通过编译,切记:? 操作符需要一个变量来承载正确的值,这个函数只会返回 Some(&i32) 或者 None,只有错误值能直接返回,正确的值不行,所以如果数组中存在 0 号元素,那么函数第二行使用 ? 后的返回类型为 &i32 而不是 Some(&i32)。因此 ? 只能用于以下形式:
let v = xxx()?;xxx()?.yyy()?;
带返回值的 main 函数
在了解了 ? 的使用限制后,这段代码你很容易看出它无法编译:
use std::fs::File; fn main() { let f = File::open("hello.txt")?; }
运行后会报错:
$ cargo run
...
the `?` operator can only be used in a function that returns `Result` or `Option` (or another type that implements `FromResidual`)
--> src/main.rs:4:48
|
3 | fn main() {
| --------- this function should return `Result` or `Option` to accept `?`
4 | let greeting_file = File::open("hello.txt")?;
| ^ cannot use the `?` operator in a function that returns `()`
|
= help: the trait `FromResidual<Result<Infallible, std::io::Error>>` is not implemented for `()`
因为 ? 要求 Result<T, E> 形式的返回值,而 main 函数的返回是 (),因此无法满足,那是不是就无解了呢?
实际上 Rust 还支持另外一种形式的 main 函数:
use std::error::Error; use std::fs::File; fn main() -> Result<(), Box<dyn Error>> { let f = File::open("hello.txt")?; Ok(()) }
这样就能使用 ? 提前返回了,同时我们又一次看到了Box<dyn Error> 特征对象,因为 std::error:Error 是 Rust 中抽象层次最高的错误,其它标准库中的错误都实现了该特征,因此我们可以用该特征对象代表一切错误,就算 main 函数中调用任何标准库函数发生错误,都可以通过 Box<dyn Error> 这个特征对象进行返回.
至于 main 函数可以有多种返回值,那是因为实现了 std::process::Termination 特征,目前为止该特征还没进入稳定版 Rust 中,也许未来你可以为自己的类型实现该特征!
try!
在 ? 横空出世之前( Rust 1.13 ),Rust 开发者还可以使用 try! 来处理错误,该宏的大致定义如下:
#![allow(unused)] fn main() { macro_rules! try { ($e:expr) => (match $e { Ok(val) => val, Err(err) => return Err(::std::convert::From::from(err)), }); } }
简单看一下与 ? 的对比:
#![allow(unused)] fn main() { // `?` let x = function_with_error()?; // 若返回 Err, 则立刻返回;若返回 Ok(255),则将 x 的值设置为 255 // `try!()` let x = try!(function_with_error()); }
可以看出 ? 的优势非常明显,何况 ? 还能做链式调用。
总之,try! 作为前浪已经死在了沙滩上,在当前版本中,我们要尽量避免使用 try!。
包和模块
当工程规模变大时,把代码写到一个甚至几个文件中,都是不太聪明的做法,可能存在以下问题:
- 单个文件过大,导致打开、翻页速度大幅变慢
- 查询和定位效率大幅降低,类比下,你会把所有知识内容放在一个几十万字的文档中吗?
- 只有一个代码层次:函数,难以维护和协作,想象一下你的操作系统只有一个根目录,剩下的都是单层子目录会如何:
disaster - 容易滋生 Bug
同时,将大的代码文件拆分成包和模块,还允许我们实现代码抽象和复用:将你的代码封装好后提供给用户,那么用户只需要调用公共接口即可,无需知道内部该如何实现。
因此,跟其它语言一样,Rust 也提供了相应概念用于代码的组织管理:
- 项目(Packages):一个
Cargo提供的feature,可以用来构建、测试和分享包 - 包(Crate):一个由多个模块组成的树形结构,可以作为三方库进行分发,也可以生成可执行文件进行运行
- 模块(Module):可以一个文件多个模块,也可以一个文件一个模块,模块可以被认为是真实项目中的代码组织单元
下面,让我们一一来学习这些概念以及如何在实践中运用。
包和 Package
真实项目远比我们之前的 cargo new 的默认目录结构要复杂,好在,Rust 为我们提供了强大的包管理工具:
- 项目(Package):可以用来构建、测试和分享包
- 工作空间(WorkSpace):对于大型项目,可以进一步将多个包联合在一起,组织成工作空间
- 包(Crate):一个由多个模块组成的树形结构,可以作为三方库进行分发,也可以生成可执行文件进行运行
- 模块(Module):可以一个文件多个模块,也可以一个文件一个模块,模块可以被认为是真实项目中的代码组织单元
定义
其实项目 Package 和包 Crate 很容易被搞混,甚至在很多书中,这两者都是不分的,但是由于官方对此做了明确的区分,因此我们会在本章节中试图(挣扎着)理清这个概念。
包 Crate
对于 Rust 而言,包是一个独立的可编译单元,它编译后会生成一个可执行文件或者一个库。
一个包会将相关联的功能打包在一起,使得该功能可以很方便的在多个项目中分享。例如标准库中没有提供但是在三方库中提供的 rand 包,它提供了随机数生成的功能,我们只需要将该包通过 use rand; 引入到当前项目的作用域中,就可以在项目中使用 rand 的功能:rand::XXX。
同一个包中不能有同名的类型,但是在不同包中就可以。例如,虽然 rand 包中,有一个 Rng 特征,可是我们依然可以在自己的项目中定义一个 Rng,前者通过 rand::Rng 访问,后者通过 Rng 访问,对于编译器而言,这两者的边界非常清晰,不会存在引用歧义。
项目 Package
鉴于 Rust 团队标新立异的起名传统,以及包的名称被 crate 占用,库的名称被 library 占用,经过斟酌, 我们决定将 Package 翻译成项目,你也可以理解为工程、软件包。
由于 Package 就是一个项目,因此它包含有独立的 Cargo.toml 文件,以及因为功能性被组织在一起的一个或多个包。一个 Package 只能包含一个库(library)类型的包,但是可以包含多个二进制可执行类型的包。
二进制 Package
让我们来创建一个二进制 Package:
$ cargo new my-project
Created binary (application) `my-project` package
$ ls my-project
Cargo.toml
src
$ ls my-project/src
main.rs
这里,Cargo 为我们创建了一个名称是 my-project 的 Package,同时在其中创建了 Cargo.toml 文件,可以看一下该文件,里面并没有提到 src/main.rs 作为程序的入口,原因是 Cargo 有一个惯例:src/main.rs 是二进制包的根文件,该二进制包的包名跟所属 Package 相同,在这里都是 my-project,所有的代码执行都从该文件中的 fn main() 函数开始。
使用 cargo run 可以运行该项目,输出:Hello, world!。
库 Package
再来创建一个库类型的 Package:
$ cargo new my-lib --lib
Created library `my-lib` package
$ ls my-lib
Cargo.toml
src
$ ls my-lib/src
lib.rs
首先,如果你试图运行 my-lib,会报错:
$ cargo run
error: a bin target must be available for `cargo run`
原因是库类型的 Package 只能作为三方库被其它项目引用,而不能独立运行,只有之前的二进制 Package 才可以运行。
与 src/main.rs 一样,Cargo 知道,如果一个 Package 包含有 src/lib.rs,意味它包含有一个库类型的同名包 my-lib,该包的根文件是 src/lib.rs。
易混淆的 Package 和包
看完上面,相信大家看出来为何 Package 和包容易被混淆了吧?因为你用 cargo new 创建的 Package 和它其中包含的包是同名的!
不过,只要你牢记 Package 是一个项目工程,而包只是一个编译单元,基本上也就不会混淆这个两个概念了:src/main.rs 和 src/lib.rs 都是编译单元,因此它们都是包。
典型的 Package 结构
上面创建的 Package 中仅包含 src/main.rs 文件,意味着它仅包含一个二进制同名包 my-project。如果一个 Package 同时拥有 src/main.rs 和 src/lib.rs,那就意味着它包含两个包:库包和二进制包,这两个包名也都是 my-project —— 都与 Package 同名。
一个真实项目中典型的 Package,会包含多个二进制包,这些包文件被放在 src/bin 目录下,每一个文件都是独立的二进制包,同时也会包含一个库包,该包只能存在一个 src/lib.rs:
.
├── Cargo.toml
├── Cargo.lock
├── src
│ ├── main.rs
│ ├── lib.rs
│ └── bin
│ └── main1.rs
│ └── main2.rs
├── tests
│ └── some_integration_tests.rs
├── benches
│ └── simple_bench.rs
└── examples
└── simple_example.rs
- 唯一库包:
src/lib.rs - 默认二进制包:
src/main.rs,编译后生成的可执行文件与Package同名 - 其余二进制包:
src/bin/main1.rs和src/bin/main2.rs,它们会分别生成一个文件同名的二进制可执行文件 - 集成测试文件:
tests目录下 - 基准性能测试
benchmark文件:benches目录下 - 项目示例:
examples目录下
这种目录结构基本上是 Rust 的标准目录结构,在 GitHub 的大多数项目上,你都将看到它的身影。
理解了包的概念,我们再来看看构成包的基本单元:模块。
模块 Module
在本章节,我们将深入讲讲 Rust 的代码构成单元:模块。使用模块可以将包中的代码按照功能性进行重组,最终实现更好的可读性及易用性。同时,我们还能非常灵活地去控制代码的可见性,进一步强化 Rust 的安全性。
创建嵌套模块
小旅馆,sorry,是小餐馆,相信大家都挺熟悉的,学校外的估计也没少去,那么咱就用小餐馆为例,来看看 Rust 的模块该如何使用。
使用 cargo new --lib restaurant 创建一个小餐馆,注意,这里创建的是一个库类型的 Package,然后将以下代码放入 src/lib.rs 中:
#![allow(unused)] fn main() { // 餐厅前厅,用于吃饭 mod front_of_house { mod hosting { fn add_to_waitlist() {} fn seat_at_table() {} } mod serving { fn take_order() {} fn serve_order() {} fn take_payment() {} } } }
以上的代码创建了三个模块,有几点需要注意的:
- 使用
mod关键字来创建新模块,后面紧跟着模块名称 - 模块可以嵌套,这里嵌套的原因是招待客人和服务都发生在前厅,因此我们的代码模拟了真实场景
- 模块中可以定义各种 Rust 类型,例如函数、结构体、枚举、特征等
- 所有模块均定义在同一个文件中
类似上述代码中所做的,使用模块,我们就能将功能相关的代码组织到一起,然后通过一个模块名称来说明这些代码为何被组织在一起。这样其它程序员在使用你的模块时,就可以更快地理解和上手。
模块树
在上一节中,我们提到过 src/main.rs 和 src/lib.rs 被称为包根(crate root),这个奇葩名称的来源(我不想承认是自己翻译水平太烂-,-)是由于这两个文件的内容形成了一个模块 crate,该模块位于包的树形结构(由模块组成的树形结构)的根部:
crate
└── front_of_house
├── hosting
│ ├── add_to_waitlist
│ └── seat_at_table
└── serving
├── take_order
├── serve_order
└── take_payment
这颗树展示了模块之间彼此的嵌套关系,因此被称为模块树。其中 crate 包根是 src/lib.rs 文件,包根文件中的三个模块分别形成了模块树的剩余部分。
父子模块
如果模块 A 包含模块 B,那么 A 是 B 的父模块,B 是 A 的子模块。在上例中,front_of_house 是 hosting 和 serving 的父模块,反之,后两者是前者的子模块。
聪明的读者,应该能联想到,模块树跟计算机上文件系统目录树的相似之处。不仅仅是组织结构上的相似,就连使用方式都很相似:每个文件都有自己的路径,用户可以通过这些路径使用它们,在 Rust 中,我们也通过路径的方式来引用模块。
用路径引用模块
想要调用一个函数,就需要知道它的路径,在 Rust 中,这种路径有两种形式:
- 绝对路径,从包根开始,路径名以包名或者
crate作为开头 - 相对路径,从当前模块开始,以
self,super或当前模块的标识符作为开头
让我们继续经营那个惨淡的小餐馆,这次为它实现一个小功能: 文件名:src/lib.rs
#![allow(unused)] fn main() { mod front_of_house { mod hosting { fn add_to_waitlist() {} } } pub fn eat_at_restaurant() { // 绝对路径 crate::front_of_house::hosting::add_to_waitlist(); // 相对路径 front_of_house::hosting::add_to_waitlist(); } }
上面的代码为了简化实现,省去了其余模块和函数,这样可以把关注点放在函数调用上。eat_at_restaurant 是一个定义在包根中的函数,在该函数中使用了两种方式对 add_to_waitlist 进行调用。
绝对路径引用
因为 eat_at_restaurant 和 add_to_waitlist 都定义在一个包中,因此在绝对路径引用时,可以直接以 crate 开头,然后逐层引用,每一层之间使用 :: 分隔:
#![allow(unused)] fn main() { crate::front_of_house::hosting::add_to_waitlist(); }
对比下之前的模块树:
crate
└── eat_at_restaurant
└── front_of_house
├── hosting
│ ├── add_to_waitlist
│ └── seat_at_table
└── serving
├── take_order
├── serve_order
└── take_payment
可以看出,绝对路径的调用,完全符合了模块树的层级递进,非常符合直觉,如果类比文件系统,就跟使用绝对路径调用可执行程序差不多:/front_of_house/hosting/add_to_waitlist,使用 crate 作为开始就和使用 / 作为开始一样。
相对路径引用
再回到模块树中,因为 eat_at_restaurant 和 front_of_house 都处于包根 crate 中,因此相对路径可以使用 front_of_house 作为开头:
#![allow(unused)] fn main() { front_of_house::hosting::add_to_waitlist(); }
如果类比文件系统,那么它类似于调用同一个目录下的程序,你可以这么做:front_of_house/hosting/add_to_waitlist,嗯也很符合直觉。
绝对还是相对?
如果只是为了引用到指定模块中的对象,那么两种都可以,但是在实际使用时,需要遵循一个原则:当代码被挪动位置时,尽量减少引用路径的修改,相信大家都遇到过,修改了某处代码,导致所有路径都要挨个替换,这显然不是好的路径选择。
回到之前的例子,如果我们把 front_of_house 模块和 eat_at_restaurant 移动到一个模块中 customer_experience,那么绝对路径的引用方式就必须进行修改:crate::customer_experience::front_of_house ...,但是假设我们使用的相对路径,那么该路径就无需修改,因为它们两个的相对位置其实没有变:
crate
└── customer_experience
└── eat_at_restaurant
└── front_of_house
├── hosting
│ ├── add_to_waitlist
│ └── seat_at_table
从新的模块树中可以很清晰的看出这一点。
再比如,其它的都不动,把 eat_at_restaurant 移动到模块 dining 中,如果使用相对路径,你需要修改该路径,但如果使用的是绝对路径,就无需修改:
crate
└── dining
└── eat_at_restaurant
└── front_of_house
├── hosting
│ ├── add_to_waitlist
不过,如果不确定哪个好,你可以考虑优先使用绝对路径,因为调用的地方和定义的地方往往是分离的,而定义的地方较少会变动。
代码可见性
让我们运行下面(之前)的代码:
#![allow(unused)] fn main() { mod front_of_house { mod hosting { fn add_to_waitlist() {} } } pub fn eat_at_restaurant() { // 绝对路径 crate::front_of_house::hosting::add_to_waitlist(); // 相对路径 front_of_house::hosting::add_to_waitlist(); } }
意料之外的报错了,毕竟看上去确实很简单且没有任何问题:
error[E0603]: module `hosting` is private
--> src/lib.rs:9:28
|
9 | crate::front_of_house::hosting::add_to_waitlist();
| ^^^^^^^ private module
错误信息很清晰:hosting 模块是私有的,无法在包根进行访问,那么为何 front_of_house 模块就可以访问?因为它和 eat_at_restaurant 同属于一个包根作用域内,同一个模块内的代码自然不存在私有化问题(所以我们之前章节的代码都没有报过这个错误!)。
模块不仅仅对于组织代码很有用,它还能定义代码的私有化边界:在这个边界内,什么内容能让外界看到,什么内容不能,都有很明确的定义。因此,如果希望让函数或者结构体等类型变成私有化的,可以使用模块。
Rust 出于安全的考虑,默认情况下,所有的类型都是私有化的,包括函数、方法、结构体、枚举、常量,是的,就连模块本身也是私有化的。在中国,父亲往往不希望孩子拥有小秘密,但是在 Rust 中,父模块完全无法访问子模块中的私有项,但是子模块却可以访问父模块、父父..模块的私有项。
pub 关键字
类似其它语言的 public 或者 Go 语言中的首字母大写,Rust 提供了 pub 关键字,通过它你可以控制模块和模块中指定项的可见性。
由于之前的解释,我们知道了只需要将 hosting 模块标记为对外可见即可:
#![allow(unused)] fn main() { mod front_of_house { pub mod hosting { fn add_to_waitlist() {} } } /*--- snip ----*/ }
但是不幸的是,又报错了:
error[E0603]: function `add_to_waitlist` is private
--> src/lib.rs:12:30
|
12 | front_of_house::hosting::add_to_waitlist();
| ^^^^^^^^^^^^^^^ private function
哦?难道模块可见还不够,还需要将函数 add_to_waitlist 标记为可见的吗? 是的,没错,模块可见性不代表模块内部项的可见性,模块的可见性仅仅是允许其它模块去引用它,但是想要引用它内部的项,还得继续将对应的项标记为 pub。
在实际项目中,一个模块需要对外暴露的数据和 API 往往就寥寥数个,如果将模块标记为可见代表着内部项也全部对外可见,那你是不是还得把那些不可见的,一个一个标记为 private?反而是更麻烦的多。
既然知道了如何解决,那么我们为函数也标记上 pub:
#![allow(unused)] fn main() { mod front_of_house { pub mod hosting { pub fn add_to_waitlist() {} } } /*--- snip ----*/ }
Bang,顺利通过编译,感觉自己又变强了。
使用 super 引用模块
在用路径引用模块中,我们提到了相对路径有三种方式开始:self、super和 crate 或者模块名,其中第三种在前面已经讲到过,现在来看看通过 super 的方式引用模块项。
super 代表的是父模块为开始的引用方式,非常类似于文件系统中的 .. 语法:../a/b
文件名:src/lib.rs
#![allow(unused)] fn main() { fn serve_order() {} // 厨房模块 mod back_of_house { fn fix_incorrect_order() { cook_order(); super::serve_order(); } fn cook_order() {} } }
嗯,我们的小餐馆又完善了,终于有厨房了!看来第一个客人也快可以有了。。。在厨房模块中,使用 super::serve_order 语法,调用了父模块(包根)中的 serve_order 函数。
那么你可能会问,为何不使用 crate::serve_order 的方式?额,其实也可以,不过如果你确定未来这种层级关系不会改变,那么 super::serve_order 的方式会更稳定,未来就算它们都不在包根了,依然无需修改引用路径。所以路径的选用,往往还是取决于场景,以及未来代码的可能走向。
使用 self 引用模块
self 其实就是引用自身模块中的项,也就是说和我们之前章节的代码类似,都调用同一模块中的内容,区别在于之前章节中直接通过名称调用即可,而 self,你得多此一举:
#![allow(unused)] fn main() { fn serve_order() { self::back_of_house::cook_order() } mod back_of_house { fn fix_incorrect_order() { cook_order(); crate::serve_order(); } pub fn cook_order() {} } }
是的,多此一举,因为完全可以直接调用 back_of_house,但是 self 还有一个大用处,在下一节中我们会讲。
结构体和枚举的可见性
为何要把结构体和枚举的可见性单独拎出来讲呢?因为这两个家伙的成员字段拥有完全不同的可见性:
- 将结构体设置为
pub,但它的所有字段依然是私有的 - 将枚举设置为
pub,它的所有字段也将对外可见
原因在于,枚举和结构体的使用方式不一样。如果枚举的成员对外不可见,那该枚举将一点用都没有,因此枚举成员的可见性自动跟枚举可见性保持一致,这样可以简化用户的使用。
而结构体的应用场景比较复杂,其中的字段也往往部分在 A 处被使用,部分在 B 处被使用,因此无法确定成员的可见性,那索性就设置为全部不可见,将选择权交给程序员。
模块与文件分离
在之前的例子中,我们所有的模块都定义在 src/lib.rs 中,但是当模块变多或者变大时,需要将模块放入一个单独的文件中,让代码更好维护。
现在,把 front_of_house 前厅分离出来,放入一个单独的文件中 src/front_of_house.rs:
#![allow(unused)] fn main() { pub mod hosting { pub fn add_to_waitlist() {} } }
然后,将以下代码留在 src/lib.rs 中:
#![allow(unused)] fn main() { mod front_of_house; pub use crate::front_of_house::hosting; pub fn eat_at_restaurant() { hosting::add_to_waitlist(); hosting::add_to_waitlist(); hosting::add_to_waitlist(); } }
so easy!其实跟之前在同一个文件中也没有太大的不同,但是有几点值得注意:
mod front_of_house;告诉 Rust 从另一个和模块front_of_house同名的文件中加载该模块的内容- 使用绝对路径的方式来引用
hosting模块:crate::front_of_house::hosting;
需要注意的是,和之前代码中 mod front_of_house{..} 的完整模块不同,现在的代码中,模块的声明和实现是分离的,实现是在单独的 front_of_house.rs 文件中,然后通过 mod front_of_house; 这条声明语句从该文件中把模块内容加载进来。因此我们可以认为,模块 front_of_house 的定义还是在 src/lib.rs 中,只不过模块的具体内容被移动到了 src/front_of_house.rs 文件中。
在这里出现了一个新的关键字 use,联想到其它章节我们见过的标准库引入 use std::fmt;,可以大致猜测,该关键字用来将外部模块中的项引入到当前作用域中来,这样无需冗长的父模块前缀即可调用:hosting::add_to_waitlist();,在下节中,我们将对 use 进行详细的讲解。
当一个模块有许多子模块时,我们也可以通过文件夹的方式来组织这些子模块。
在上述例子中,我们可以创建一个目录 front_of_house,然后在文件夹里创建一个 hosting.rs 文件,hosting.rs 文件现在就剩下:
#![allow(unused)] fn main() { pub fn add_to_waitlist() {} }
现在,我们尝试编译程序,很遗憾,编译器报错:
error[E0583]: file not found for module `front_of_house`
--> src/lib.rs:3:1
|
1 | mod front_of_house;
| ^^^^^^^^^^^^^^^^^^
|
= help: to create the module `front_of_house`, create file "src/front_of_house.rs" or "src/front_of_house/mod.rs"
是的,如果需要将文件夹作为一个模块,我们需要进行显示指定暴露哪些子模块。按照上述的报错信息,我们有两种方法:
- 在
front_of_house目录里创建一个mod.rs,如果你使用的rustc版本1.30之前,这是唯一的方法。 - 在
front_of_house同级目录里创建一个与模块(目录)同名的 rs 文件front_of_house.rs,在新版本里,更建议使用这样的命名方式来避免项目中存在大量同名的mod.rs文件( Python 点了个踩)。
而无论是上述哪个方式创建的文件,其内容都是一样的,你需要定义你的子模块(子模块名与文件名相同):
#![allow(unused)] fn main() { pub mod hosting; // pub mod serving; }
使用 use 及受限可见性
如果代码中,通篇都是 crate::front_of_house::hosting::add_to_waitlist 这样的函数调用形式,我不知道有谁会喜欢,也许靠代码行数赚工资的人会很喜欢,但是强迫症肯定受不了,悲伤的是程序员大多都有强迫症。。。
因此我们需要一个办法来简化这种使用方式,在 Rust 中,可以使用 use 关键字把路径提前引入到当前作用域中,随后的调用就可以省略该路径,极大地简化了代码。
基本引入方式
在 Rust 中,引入模块中的项有两种方式:绝对路径和相对路径,这两者在前面章节都有讲过,就不再赘述,先来看看使用绝对路径的引入方式。
绝对路径引入模块
#![allow(unused)] fn main() { mod front_of_house { pub mod hosting { pub fn add_to_waitlist() {} } } use crate::front_of_house::hosting; pub fn eat_at_restaurant() { hosting::add_to_waitlist(); hosting::add_to_waitlist(); hosting::add_to_waitlist(); } }
这里,我们使用 use 和绝对路径的方式,将 hosting 模块引入到当前作用域中,然后只需通过 hosting::add_to_waitlist 的方式,即可调用目标模块中的函数,相比 crate::front_of_house::hosting::add_to_waitlist() 的方式要简单的多,那么还能更简单吗?
相对路径引入模块中的函数
在下面代码中,我们不仅要使用相对路径进行引入,而且与上面引入 hosting 模块不同,直接引入该模块中的 add_to_waitlist 函数:
#![allow(unused)] fn main() { mod front_of_house { pub mod hosting { pub fn add_to_waitlist() {} } } use front_of_house::hosting::add_to_waitlist; pub fn eat_at_restaurant() { add_to_waitlist(); add_to_waitlist(); add_to_waitlist(); } }
很明显,三兄弟又变得更短了,不过,怎么觉得这句话怪怪的。。
引入模块还是函数
从使用简洁性来说,引入函数自然是更甚一筹,但是在某些时候,引入模块会更好:
- 需要引入同一个模块的多个函数
- 作用域中存在同名函数
在以上两种情况中,使用 use front_of_house::hosting 引入模块要比 use front_of_house::hosting::add_to_waitlist; 引入函数更好。
例如,如果想使用 HashMap,那么直接引入该结构体是比引入模块更好的选择,因为在 collections 模块中,我们只需要使用一个 HashMap 结构体:
use std::collections::HashMap; fn main() { let mut map = HashMap::new(); map.insert(1, 2); }
其实严格来说,对于引用方式并没有需要遵守的惯例,主要还是取决于你的喜好,不过我们建议:优先使用最细粒度(引入函数、结构体等)的引用方式,如果引起了某种麻烦(例如前面两种情况),再使用引入模块的方式。
避免同名引用
根据上一章节的内容,我们只要保证同一个模块中不存在同名项就行,模块之间、包之间的同名,谁管得着谁啊,话虽如此,一起看看,如果遇到同名的情况该如何处理。
模块::函数
#![allow(unused)] fn main() { use std::fmt; use std::io; fn function1() -> fmt::Result { // --snip-- } fn function2() -> io::Result<()> { // --snip-- } }
上面的例子给出了很好的解决方案,使用模块引入的方式,具体的 Result 通过 模块::Result 的方式进行调用。
可以看出,避免同名冲突的关键,就是使用父模块的方式来调用,除此之外,还可以给予引入的项起一个别名。
as 别名引用
对于同名冲突问题,还可以使用 as 关键字来解决,它可以赋予引入项一个全新的名称:
#![allow(unused)] fn main() { use std::fmt::Result; use std::io::Result as IoResult; fn function1() -> Result { // --snip-- } fn function2() -> IoResult<()> { // --snip-- } }
如上所示,首先通过 use std::io::Result 将 Result 引入到作用域,然后使用 as 给予它一个全新的名称 IoResult,这样就不会再产生冲突:
Result代表std::fmt::ResultIoResult代表std:io::Result
引入项再导出
当外部的模块项 A 被引入到当前模块中时,它的可见性自动被设置为私有的,如果你希望允许其它外部代码引用我们的模块项 A,那么可以对它进行再导出:
#![allow(unused)] fn main() { mod front_of_house { pub mod hosting { pub fn add_to_waitlist() {} } } pub use crate::front_of_house::hosting; pub fn eat_at_restaurant() { hosting::add_to_waitlist(); hosting::add_to_waitlist(); hosting::add_to_waitlist(); } }
如上,使用 pub use 即可实现。这里 use 代表引入 hosting 模块到当前作用域,pub 表示将该引入的内容再度设置为可见。
当你希望将内部的实现细节隐藏起来或者按照某个目的组织代码时,可以使用 pub use 再导出,例如统一使用一个模块来提供对外的 API,那该模块就可以引入其它模块中的 API,然后进行再导出,最终对于用户来说,所有的 API 都是由一个模块统一提供的。
使用第三方包
之前我们一直在引入标准库模块或者自定义模块,现在来引入下第三方包中的模块,关于如何引入外部依赖,这里直接给出操作步骤:
- 修改
Cargo.toml文件,在[dependencies]区域添加一行:rand = "0.8.3" - 此时,如果你用的是
VSCode和rust-analyzer插件,该插件会自动拉取该库,你可能需要等它完成后,再进行下一步(VSCode 左下角有提示)
好了,此时,rand 包已经被我们添加到依赖中,下一步就是在代码中使用:
use rand::Rng; fn main() { let secret_number = rand::thread_rng().gen_range(1..101); }
这里使用 use 引入了第三方包 rand 中的 Rng 特征,因为我们需要调用的 gen_range 方法定义在该特征中。
crates.io,lib.rs
Rust 社区已经为我们贡献了大量高质量的第三方包,你可以在 crates.io 或者 lib.rs 中检索和使用,从目前来说查找包更推荐 lib.rs,搜索功能更强大,内容展示也更加合理,但是下载依赖包还是得用crates.io。
你可以在网站上搜索 rand 包,看看它的文档使用方式是否和我们之前引入方式相一致:在网上找到想要的包,然后将你想要的包和版本信息写入到 Cargo.toml 中。
使用 {} 简化引入方式
对于以下一行一行的引入方式:
#![allow(unused)] fn main() { use std::collections::HashMap; use std::collections::BTreeMap; use std::collections::HashSet; use std::cmp::Ordering; use std::io; }
可以使用 {} 来一起引入进来,在大型项目中,使用这种方式来引入,可以减少大量 use 的使用:
#![allow(unused)] fn main() { use std::collections::{HashMap,BTreeMap,HashSet}; use std::{cmp::Ordering, io}; }
对于下面的同时引入模块和模块中的项:
#![allow(unused)] fn main() { use std::io; use std::io::Write; }
可以使用 {} 的方式进行简化:
#![allow(unused)] fn main() { use std::io::{self, Write}; }
self
上面使用到了模块章节提到的 self 关键字,用来替代模块自身,结合上一节中的 self,可以得出它在模块中的两个用途:
use self::xxx,表示加载当前模块中的xxx。此时self可省略use xxx::{self, yyy},表示,加载当前路径下模块xxx本身,以及模块xxx下的yyy
使用 * 引入模块下的所有项
对于之前一行一行引入 std::collections 的方式,我们还可以使用
#![allow(unused)] fn main() { use std::collections::*; }
以上这种方式来引入 std::collections 模块下的所有公共项,这些公共项自然包含了 HashMap,HashSet 等想手动引入的集合类型。
当使用 * 来引入的时候要格外小心,因为你很难知道到底哪些被引入到了当前作用域中,有哪些会和你自己程序中的名称相冲突:
use std::collections::*; struct HashMap; fn main() { let mut v = HashMap::new(); v.insert("a", 1); }
以上代码中,std::collection::HashMap 被 * 引入到当前作用域,但是由于存在另一个同名的结构体,因此 HashMap::new 根本不存在,因为对于编译器来说,本地同名类型的优先级更高。
在实际项目中,这种引用方式往往用于快速写测试代码,它可以把所有东西一次性引入到 tests 模块中。
受限的可见性
在上一节中,我们学习了可见性这个概念,这也是模块体系中最为核心的概念,控制了模块中哪些内容可以被外部看见,但是在实际使用时,光被外面看到还不行,我们还想控制哪些人能看,这就是 Rust 提供的受限可见性。
例如,在 Rust 中,包是一个模块树,我们可以通过 pub(crate) item; 这种方式来实现:item 虽然是对外可见的,但是只在当前包内可见,外部包无法引用到该 item。
所以,如果我们想要让某一项可以在整个包中都可以被使用,那么有两种办法:
- 在包根中定义一个非
pub类型的X(父模块的项对子模块都是可见的,因此包根中的项对模块树上的所有模块都可见) - 在子模块中定义一个
pub类型的Y,同时通过use将其引入到包根
#![allow(unused)] fn main() { mod a { pub mod b { pub fn c() { println!("{:?}",crate::X); } #[derive(Debug)] pub struct Y; } } #[derive(Debug)] struct X; use a::b::Y; fn d() { println!("{:?}",Y); } }
以上代码充分说明了之前两种办法的使用方式,但是有时我们会遇到这两种方法都不太好用的时候。例如希望对于某些特定的模块可见,但是对于其他模块又不可见:
#![allow(unused)] fn main() { // 目标:`a` 导出 `I`、`bar` and `foo`,其他的不导出 pub mod a { pub const I: i32 = 3; fn semisecret(x: i32) -> i32 { use self::b::c::J; x + J } pub fn bar(z: i32) -> i32 { semisecret(I) * z } pub fn foo(y: i32) -> i32 { semisecret(I) + y } mod b { mod c { const J: i32 = 4; } } } }
这段代码会报错,因为与父模块中的项对子模块可见相反,子模块中的项对父模块是不可见的。这里 semisecret 方法中,a -> b -> c 形成了父子模块链,那 c 中的 J 自然对 a 模块不可见。
如果使用之前的可见性方式,那么想保持 J 私有,同时让 a 继续使用 semisecret 函数的办法是将该函数移动到 c 模块中,然后用 pub use 将 semisecret 函数进行再导出:
#![allow(unused)] fn main() { pub mod a { pub const I: i32 = 3; use self::b::semisecret; pub fn bar(z: i32) -> i32 { semisecret(I) * z } pub fn foo(y: i32) -> i32 { semisecret(I) + y } mod b { pub use self::c::semisecret; mod c { const J: i32 = 4; pub fn semisecret(x: i32) -> i32 { x + J } } } } }
这段代码说实话问题不大,但是有些破坏了我们之前的逻辑,如果想保持代码逻辑,同时又只让 J 在 a 内可见该怎么办?
#![allow(unused)] fn main() { pub mod a { pub const I: i32 = 3; fn semisecret(x: i32) -> i32 { use self::b::c::J; x + J } pub fn bar(z: i32) -> i32 { semisecret(I) * z } pub fn foo(y: i32) -> i32 { semisecret(I) + y } mod b { pub(in crate::a) mod c { pub(in crate::a) const J: i32 = 4; } } } }
通过 pub(in crate::a) 的方式,我们指定了模块 c 和常量 J 的可见范围都只是 a 模块中,a 之外的模块是完全访问不到它们的。
限制可见性语法
pub(crate) 或 pub(in crate::a) 就是限制可见性语法,前者是限制在整个包内可见,后者是通过绝对路径,限制在包内的某个模块内可见,总结一下:
pub意味着可见性无任何限制pub(crate)表示在当前包可见pub(self)在当前模块可见pub(super)在父模块可见pub(in <path>)表示在某个路径代表的模块中可见,其中path必须是父模块或者祖先模块
一个综合例子
// 一个名为 `my_mod` 的模块 mod my_mod { // 模块中的项默认具有私有的可见性 fn private_function() { println!("called `my_mod::private_function()`"); } // 使用 `pub` 修饰语来改变默认可见性。 pub fn function() { println!("called `my_mod::function()`"); } // 在同一模块中,项可以访问其它项,即使它是私有的。 pub fn indirect_access() { print!("called `my_mod::indirect_access()`, that\n> "); private_function(); } // 模块也可以嵌套 pub mod nested { pub fn function() { println!("called `my_mod::nested::function()`"); } #[allow(dead_code)] fn private_function() { println!("called `my_mod::nested::private_function()`"); } // 使用 `pub(in path)` 语法定义的函数只在给定的路径中可见。 // `path` 必须是父模块(parent module)或祖先模块(ancestor module) pub(in crate::my_mod) fn public_function_in_my_mod() { print!("called `my_mod::nested::public_function_in_my_mod()`, that\n > "); public_function_in_nested() } // 使用 `pub(self)` 语法定义的函数则只在当前模块中可见。 pub(self) fn public_function_in_nested() { println!("called `my_mod::nested::public_function_in_nested"); } // 使用 `pub(super)` 语法定义的函数只在父模块中可见。 pub(super) fn public_function_in_super_mod() { println!("called my_mod::nested::public_function_in_super_mod"); } } pub fn call_public_function_in_my_mod() { print!("called `my_mod::call_public_funcion_in_my_mod()`, that\n> "); nested::public_function_in_my_mod(); print!("> "); nested::public_function_in_super_mod(); } // `pub(crate)` 使得函数只在当前包中可见 pub(crate) fn public_function_in_crate() { println!("called `my_mod::public_function_in_crate()"); } // 嵌套模块的可见性遵循相同的规则 mod private_nested { #[allow(dead_code)] pub fn function() { println!("called `my_mod::private_nested::function()`"); } } } fn function() { println!("called `function()`"); } fn main() { // 模块机制消除了相同名字的项之间的歧义。 function(); my_mod::function(); // 公有项,包括嵌套模块内的,都可以在父模块外部访问。 my_mod::indirect_access(); my_mod::nested::function(); my_mod::call_public_function_in_my_mod(); // pub(crate) 项可以在同一个 crate 中的任何地方访问 my_mod::public_function_in_crate(); // pub(in path) 项只能在指定的模块中访问 // 报错!函数 `public_function_in_my_mod` 是私有的 //my_mod::nested::public_function_in_my_mod(); // 试一试 ^ 取消该行的注释 // 模块的私有项不能直接访问,即便它是嵌套在公有模块内部的 // 报错!`private_function` 是私有的 //my_mod::private_function(); // 试一试 ^ 取消此行注释 // 报错!`private_function` 是私有的 //my_mod::nested::private_function(); // 试一试 ^ 取消此行的注释 // 报错! `private_nested` 是私有的 //my_mod::private_nested::function(); // 试一试 ^ 取消此行的注释 }
注释和文档
在之前的章节我们学习了包和模块如何使用,在此章节将进一步学习如何书写文档注释,以及如何使用 cargo doc 生成项目的文档,最后将以一个包、模块和文档的综合性例子,来将这些知识融会贯通。
注释的种类
在 Rust 中,注释分为三类:
- 代码注释,用于说明某一块代码的功能,读者往往是同一个项目的协作开发者
- 文档注释,支持
Markdown,对项目描述、公共 API 等用户关心的功能进行介绍,同时还能提供示例代码,目标读者往往是想要了解你项目的人 - 包和模块注释,严格来说这也是文档注释中的一种,它主要用于说明当前包和模块的功能,方便用户迅速了解一个项目
通过这些注释,实现了 Rust 极其优秀的文档化支持,甚至你还能在文档注释中写测试用例,省去了单独写测试用例的环节,我直呼好家伙!
代码注释
显然之前的刮目相看是打了引号的,想要去掉引号,该写注释的时候,就老老实实的,不过写时需要遵循八字原则:围绕目标,言简意赅,记住,洋洋洒洒那是用来形容文章的,不是形容注释!
代码注释方式有两种:
行注释 //
fn main() { // 我是Sun... // face let name = "sunface"; let age = 18; // 今年好像是18岁 }
如上所示,行注释可以放在某一行代码的上方,也可以放在当前代码行的后方。如果超出一行的长度,需要在新行的开头也加上 //。
当注释行数较多时,你还可以使用块注释
块注释/* ..... */
fn main() { /* 我 是 S u n ... 哎,好长! */ let name = "sunface"; let age = "???"; // 今年其实。。。挺大了 }
如上所示,只需要将注释内容使用 /* */ 进行包裹即可。
你会发现,Rust 的代码注释跟其它语言并没有区别,主要区别其实在于文档注释这一块,也是本章节内容的重点。
文档注释
当查看一个 crates.io 上的包时,往往需要通过它提供的文档来浏览相关的功能特性、使用方式,这种文档就是通过文档注释实现的。
Rust 提供了 cargo doc 的命令,可以用于把这些文档注释转换成 HTML 网页文件,最终展示给用户浏览,这样用户就知道这个包是做什么的以及该如何使用。
文档行注释 ///
本书的一大特点就是废话不多,因此我们开门见山:
#![allow(unused)] fn main() { /// `add_one` 将指定值加1 /// /// # Examples /// /// ``` /// let arg = 5; /// let answer = my_crate::add_one(arg); /// /// assert_eq!(6, answer); /// ``` pub fn add_one(x: i32) -> i32 { x + 1 } }
以上代码有几点需要注意:
- 文档注释需要位于
lib类型的包中,例如src/lib.rs中 - 文档注释可以使用
markdown语法!例如# Examples的标题,以及代码块高亮 - 被注释的对象需要使用
pub对外可见,记住:文档注释是给用户看的,内部实现细节不应该被暴露出去
咦?文档注释中的例子,为什看上去像是能运行的样子?竟然还是有 assert_eq 这种常用于测试目的的宏。 嗯,你的感觉没错,详细内容会在本章后面讲解,容我先卖个关子。
文档块注释 /** ... */
与代码注释一样,文档也有块注释,当注释内容多时,使用块注释可以减少 /// 的使用:
#![allow(unused)] fn main() { /** `add_two` 将指定值加2 Examples ``` let arg = 5; let answer = my_crate::add_two(arg); assert_eq!(7, answer); ``` */ pub fn add_two(x: i32) -> i32 { x + 2 } }
查看文档 cargo doc
锦衣不夜行,这是中国人的传统美德。我们写了这么漂亮的文档注释,当然要看看网页中是什么效果咯。
很简单,运行 cargo doc 可以直接生成 HTML 文件,放入target/doc目录下。
当然,为了方便,我们使用 cargo doc --open 命令,可以在生成文档后,自动在浏览器中打开网页,最终效果如图所示:
非常棒,而且非常简单,这就是 Rust 工具链的强大之处。
常用文档标题
之前我们见到了在文档注释中该如何使用 markdown,其中包括 # Examples 标题。除了这个标题,还有一些常用的,你可以在项目中酌情使用:
- Panics:函数可能会出现的异常状况,这样调用函数的人就可以提前规避
- Errors:描述可能出现的错误及什么情况会导致错误,有助于调用者针对不同的错误采取不同的处理方式
- Safety:如果函数使用
unsafe代码,那么调用者就需要注意一些使用条件,以确保unsafe代码块的正常工作
话说回来,这些标题更多的是一种惯例,如果你非要用中文标题也没问题,但是最好在团队中保持同样的风格 :)
包和模块级别的注释
除了函数、结构体等 Rust 项的注释,你还可以给包和模块添加注释,需要注意的是,这些注释要添加到包、模块的最上方!
与之前的任何注释一样,包级别的注释也分为两种:行注释 //! 和块注释 /*! ... */。
现在,为我们的包增加注释,在 src/lib.rs 包根的最上方,添加:
#![allow(unused)] fn main() { /*! lib包是world_hello二进制包的依赖包, 里面包含了compute等有用模块 */ pub mod compute; }
然后再为该包根的子模块 src/compute.rs 添加注释:
#![allow(unused)] fn main() { //! 计算一些你口算算不出来的复杂算术题 /// `add_one`将指定值加1 /// }
运行 cargo doc --open 查看下效果:
包模块注释,可以让用户从整体的角度理解包的用途,对于用户来说是非常友好的,就和一篇文章的开头一样,总是要对文章的内容进行大致的介绍,让用户在看的时候心中有数。
至此,关于如何注释的内容,就结束了,那么注释还能用来做什么?可以玩出花来吗?答案是Yes.
文档测试(Doc Test)
相信读者之前都写过单元测试用例,其中一个很蛋疼的问题就是,随着代码的进化,单元测试用例经常会失效,过段时间后(为何是过段时间?应该这么问,有几个开发喜欢写测试用例 =,=),你发现需要连续修改不少处代码,才能让测试重新工作起来。然而,在 Rust 中,大可不必。
在之前的 add_one 中,我们写的示例代码非常像是一个单元测试的用例,这是偶然吗?并不是。因为 Rust 允许我们在文档注释中写单元测试用例!方法就如同之前做的:
#![allow(unused)] fn main() { /// `add_one` 将指定值加1 /// /// # Examples11 /// /// ``` /// let arg = 5; /// let answer = world_hello::compute::add_one(arg); /// /// assert_eq!(6, answer); /// ``` pub fn add_one(x: i32) -> i32 { x + 1 } }
以上的注释不仅仅是文档,还可以作为单元测试的用例运行,使用 cargo test 运行测试:
Doc-tests world_hello
running 2 tests
test src/compute.rs - compute::add_one (line 8) ... ok
test src/compute.rs - compute::add_two (line 22) ... ok
test result: ok. 2 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 1.00s
可以看到,文档中的测试用例被完美运行,而且输出中也明确提示了 Doc-tests world_hello,意味着这些测试的名字叫 Doc test 文档测试。
需要注意的是,你可能需要使用类如
world_hello::compute::add_one(arg)的完整路径来调用函数,因为测试是在另外一个独立的线程中运行的
造成 panic 的文档测试
文档测试中的用例还可以造成 panic:
#![allow(unused)] fn main() { /// # Panics /// /// The function panics if the second argument is zero. /// /// ```rust /// // panics on division by zero /// world_hello::compute::div(10, 0); /// ``` pub fn div(a: i32, b: i32) -> i32 { if b == 0 { panic!("Divide-by-zero error"); } a / b } }
以上测试运行后会 panic:
---- src/compute.rs - compute::div (line 38) stdout ----
Test executable failed (exit code 101).
stderr:
thread 'main' panicked at 'Divide-by-zero error', src/compute.rs:44:9
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
如果想要通过这种测试,可以添加 should_panic:
#![allow(unused)] fn main() { /// # Panics /// /// The function panics if the second argument is zero. /// /// ```rust,should_panic /// // panics on division by zero /// world_hello::compute::div(10, 0); /// ``` }
通过 should_panic,告诉 Rust 我们这个用例会导致 panic,这样测试用例就能顺利通过。
保留测试,隐藏文档
在某些时候,我们希望保留文档测试的功能,但是又要将某些测试用例的内容从文档中隐藏起来:
/// ``` /// # // 使用#开头的行会在文档中被隐藏起来,但是依然会在文档测试中运行 /// # fn try_main() -> Result<(), String> { /// let res = world_hello::compute::try_div(10, 0)?; /// # Ok(()) // returning from try_main /// # } /// # fn main() { /// # try_main().unwrap(); /// # /// # } /// ``` pub fn try_div(a: i32, b: i32) -> Result<i32, String> { if b == 0 { Err(String::from("Divide-by-zero")) } else { Ok(a / b) } }
以上文档注释中,我们使用 # 将不想让用户看到的内容隐藏起来,但是又不影响测试用例的运行,最终用户将只能看到那行没有隐藏的 let res = world_hello::compute::try_div(10, 0)?;:
文档注释中的代码跳转
Rust 在文档注释中还提供了一个非常强大的功能,那就是可以实现对外部项的链接:
跳转到标准库
#![allow(unused)] fn main() { /// `add_one` 返回一个[`Option`]类型 pub fn add_one(x: i32) -> Option<i32> { Some(x + 1) } }
此处的 [Option] 就是一个链接,指向了标准库中的 Option 枚举类型,有两种方式可以进行跳转:
- 在 IDE 中,使用
Command + 鼠标左键(macOS),CTRL + 鼠标左键(Windows) - 在文档中直接点击链接
再比如,还可以使用路径的方式跳转:
#![allow(unused)] fn main() { use std::sync::mpsc::Receiver; /// [`Receiver<T>`] [`std::future`]. /// /// [`std::future::Future`] [`Self::recv()`]. pub struct AsyncReceiver<T> { sender: Receiver<T>, } impl<T> AsyncReceiver<T> { pub async fn recv() -> T { unimplemented!() } } }
使用完整路径跳转到指定项
除了跳转到标准库,你还可以通过指定具体的路径跳转到自己代码或者其它库的指定项,例如在 lib.rs 中添加以下代码:
#![allow(unused)] fn main() { pub mod a { /// `add_one` 返回一个[`Option`]类型 /// 跳转到[`crate::MySpecialFormatter`] pub fn add_one(x: i32) -> Option<i32> { Some(x + 1) } } pub struct MySpecialFormatter; }
使用 crate::MySpecialFormatter 这种路径就可以实现跳转到 lib.rs 中定义的结构体上。
同名项的跳转
如果遇到同名项,可以使用标示类型的方式进行跳转:
#![allow(unused)] fn main() { /// 跳转到结构体 [`Foo`](struct@Foo) pub struct Bar; /// 跳转到同名函数 [`Foo`](fn@Foo) pub struct Foo {} /// 跳转到同名宏 [`foo!`] pub fn Foo() {} #[macro_export] macro_rules! foo { () => {} } }
文档搜索别名
Rust 文档支持搜索功能,我们可以为自己的类型定义几个别名,以实现更好的搜索展现,当别名命中时,搜索结果会被放在第一位:
#![allow(unused)] fn main() { #[doc(alias = "x")] #[doc(alias = "big")] pub struct BigX; #[doc(alias("y", "big"))] pub struct BigY; }
结果如下图所示:

一个综合例子
这个例子我们将重点应用几个知识点:
- 文档注释
- 一个项目可以包含两个包:二进制可执行包和
lib包(库包),它们的包根分别是src/main.rs和src/lib.rs - 在二进制包中引用
lib包 - 使用
pub use再导出 API,并观察文档
首先,使用 cargo new art 创建一个 Package art:
Created binary (application) `art` package
系统提示我们创建了一个二进制 Package,根据之前章节学过的内容,可以知道该 Package 包含一个同名的二进制包:包名为 art,包根为 src/main.rs,该包可以编译成二进制然后运行。
现在,在 src 目录下创建一个 lib.rs 文件,同样,根据之前学习的知识,创建该文件等于又创建了一个库类型的包,包名也是 art,包根为 src/lib.rs,该包是是库类型的,因此往往作为依赖库被引入。
将以下内容添加到 src/lib.rs 中:
#![allow(unused)] fn main() { //! # Art //! //! 未来的艺术建模库,现在的调色库 pub use self::kinds::PrimaryColor; pub use self::kinds::SecondaryColor; pub use self::utils::mix; pub mod kinds { //! 定义颜色的类型 /// 主色 pub enum PrimaryColor { Red, Yellow, Blue, } /// 副色 #[derive(Debug,PartialEq)] pub enum SecondaryColor { Orange, Green, Purple, } } pub mod utils { //! 实用工具,目前只实现了调色板 use crate::kinds::*; /// 将两种主色调成副色 /// ```rust /// use art::utils::mix; /// use art::kinds::{PrimaryColor,SecondaryColor}; /// assert!(matches!(mix(PrimaryColor::Yellow, PrimaryColor::Blue), SecondaryColor::Green)); /// ``` pub fn mix(c1: PrimaryColor, c2: PrimaryColor) -> SecondaryColor { SecondaryColor::Green } } }
在库包的包根 src/lib.rs 下,我们又定义了几个子模块,同时将子模块中的三个项通过 pub use 进行了再导出。
接着,将下面内容添加到 src/main.rs 中:
use art::kinds::PrimaryColor; use art::utils::mix; fn main() { let blue = PrimaryColor::Blue; let yellow = PrimaryColor::Yellow; println!("{:?}",mix(blue, yellow)); }
在二进制可执行包的包根 src/main.rs 下,我们引入了库包 art 中的模块项,同时使用 main 函数作为程序的入口,该二进制包可以使用 cargo run 运行:
Green
至此,库包完美提供了用于调色的 API,二进制包引入这些 API 完美的实现了调色并打印输出。
最后,再来看看文档长啥样:
总结
在 Rust 中,注释分为三个主要类型:代码注释、文档注释、包和模块注释,每个注释类型都拥有两种形式:行注释和块注释,熟练掌握包模块和注释的知识,非常有助于我们创建工程性更强的项目。
类型转换
Rust 是类型安全的语言,因此在 Rust 中做类型转换不是一件简单的事,这一章节我们将对 Rust 中的类型转换进行详尽讲解。
高能预警:本章节有些难,可以考虑学了进阶后回头再看
as转换
先来看一段代码:
fn main() { let a: i32 = 10; let b: u16 = 100; if a < b { println!("Ten is less than one hundred."); } }
能跟着这本书一直学习到这里,说明你对 Rust 已经有了一定的理解,那么一眼就能看出这段代码注定会报错,因为 a 和 b 拥有不同的类型,Rust 不允许两种不同的类型进行比较。
解决办法很简单,只要把 b 转换成 i32 类型即可,Rust 中内置了一些基本类型之间的转换,这里使用 as 操作符来完成: if a < (b as i32) {...}。那么为什么不把 a 转换成 u16 类型呢?
因为每个类型能表达的数据范围不同,如果把范围较大的类型转换成较小的类型,会造成错误,因此我们需要把范围较小的类型转换成较大的类型,来避免这些问题的发生。
使用类型转换需要小心,因为如果执行以下操作
300_i32 as i8,你将获得44这个值,而不是300,因为i8类型能表达的的最大值为2^7 - 1,使用以下代码可以查看i8的最大值:
#![allow(unused)] fn main() { let a = i8::MAX; println!("{}",a); }
下面列出了常用的转换形式:
fn main() { let a = 3.1 as i8; let b = 100_i8 as i32; let c = 'a' as u8; // 将字符'a'转换为整数,97 println!("{},{},{}",a,b,c) }
内存地址转换为指针
#![allow(unused)] fn main() { let mut values: [i32; 2] = [1, 2]; let p1: *mut i32 = values.as_mut_ptr(); let first_address = p1 as usize; // 将p1内存地址转换为一个整数 let second_address = first_address + 4; // 4 == std::mem::size_of::<i32>(),i32类型占用4个字节,因此将内存地址 + 4 let p2 = second_address as *mut i32; // 访问该地址指向的下一个整数p2 unsafe { *p2 += 1; } assert_eq!(values[1], 3); }
强制类型转换的边角知识
- 转换不具有传递性
就算
e as U1 as U2是合法的,也不能说明e as U2是合法的(e不能直接转换成U2)。
TryInto 转换
在一些场景中,使用 as 关键字会有比较大的限制。如果你想要在类型转换上拥有完全的控制而不依赖内置的转换,例如处理转换错误,那么可以使用 TryInto :
use std::convert::TryInto; fn main() { let a: u8 = 10; let b: u16 = 1500; let b_: u8 = b.try_into().unwrap(); if a < b_ { println!("Ten is less than one hundred."); } }
上面代码中引入了 std::convert::TryInto 特征,但是却没有使用它,可能有些同学会为此困惑,主要原因在于如果你要使用一个特征的方法,那么你需要引入该特征到当前的作用域中,我们在上面用到了 try_into 方法,因此需要引入对应的特征。但是 Rust 又提供了一个非常便利的办法,把最常用的标准库中的特征通过std::prelude模块提前引入到当前作用域中,其中包括了 std::convert::TryInto,你可以尝试删除第一行的代码 use ...,看看是否会报错。
try_into 会尝试进行一次转换,并返回一个 Result,此时就可以对其进行相应的错误处理。由于我们的例子只是为了快速测试,因此使用了 unwrap 方法,该方法在发现错误时,会直接调用 panic 导致程序的崩溃退出,在实际项目中,请不要这么使用,具体见panic部分。
最主要的是 try_into 转换会捕获大类型向小类型转换时导致的溢出错误:
fn main() { let b: i16 = 1500; let b_: u8 = match b.try_into() { Ok(b1) => b1, Err(e) => { println!("{:?}", e.to_string()); 0 } }; }
运行后输出如下 "out of range integral type conversion attempted",在这里我们程序捕获了错误,编译器告诉我们类型范围超出的转换是不被允许的,因为我们试图把 1500_i16 转换为 u8 类型,后者明显不足以承载这么大的值。
通用类型转换
虽然 as 和 TryInto 很强大,但是只能应用在数值类型上,可是 Rust 有如此多的类型,想要为这些类型实现转换,我们需要另谋出路,先来看看在一个笨办法,将一个结构体转换为另外一个结构体:
#![allow(unused)] fn main() { struct Foo { x: u32, y: u16, } struct Bar { a: u32, b: u16, } fn reinterpret(foo: Foo) -> Bar { let Foo { x, y } = foo; Bar { a: x, b: y } } }
简单粗暴,但是从另外一个角度来看,也挺啰嗦的,好在 Rust 为我们提供了更通用的方式来完成这个目的。
强制类型转换
在某些情况下,类型是可以进行隐式强制转换的,虽然这些转换弱化了 Rust 的类型系统,但是它们的存在是为了让 Rust 在大多数场景可以工作(说白了,帮助用户省事),而不是报各种类型上的编译错误。
首先,在匹配特征时,不会做任何强制转换(除了方法)。一个类型 T 可以强制转换为 U,不代表 impl T 可以强制转换为 impl U,例如下面的代码就无法通过编译检查:
trait Trait {} fn foo<X: Trait>(t: X) {} impl<'a> Trait for &'a i32 {} fn main() { let t: &mut i32 = &mut 0; foo(t); }
报错如下:
error[E0277]: the trait bound `&mut i32: Trait` is not satisfied
--> src/main.rs:9:9
|
9 | foo(t);
| ^ the trait `Trait` is not implemented for `&mut i32`
|
= help: the following implementations were found:
<&'a i32 as Trait>
= note: `Trait` is implemented for `&i32`, but not for `&mut i32`
&i32 实现了特征 Trait, &mut i32 可以转换为 &i32,但是 &mut i32 依然无法作为 Trait 来使用。
点操作符
方法调用的点操作符看起来简单,实际上非常不简单,它在调用时,会发生很多魔法般的类型转换,例如:自动引用、自动解引用,强制类型转换直到类型能匹配等。
假设有一个方法 foo,它有一个接收器(接收器就是 self、&self、&mut self 参数)。如果调用 value.foo(),编译器在调用 foo 之前,需要决定到底使用哪个 Self 类型来调用。现在假设 value 拥有类型 T。
再进一步,我们使用完全限定语法来进行准确的函数调用:
- 首先,编译器检查它是否可以直接调用
T::foo(value),称之为值方法调用 - 如果上一步调用无法完成(例如方法类型错误或者特征没有针对
Self进行实现,上文提到过特征不能进行强制转换),那么编译器会尝试增加自动引用,例如会尝试以下调用:<&T>::foo(value)和<&mut T>::foo(value),称之为引用方法调用 - 若上面两个方法依然不工作,编译器会试着解引用
T,然后再进行尝试。这里使用了Deref特征 —— 若T: Deref<Target = U>(T可以被解引用为U),那么编译器会使用U类型进行尝试,称之为解引用方法调用 - 若
T不能被解引用,且T是一个定长类型(在编译器类型长度是已知的),那么编译器也会尝试将T从定长类型转为不定长类型,例如将[i32; 2]转为[i32] - 若还是不行,那...没有那了,最后编译器大喊一声:汝欺我甚,不干了!
下面我们来用一个例子来解释上面的方法查找算法:
#![allow(unused)] fn main() { let array: Rc<Box<[T; 3]>> = ...; let first_entry = array[0]; }
array 数组的底层数据隐藏在了重重封锁之后,那么编译器如何使用 array[0] 这种数组原生访问语法通过重重封锁,准确的访问到数组中的第一个元素?
- 首先,
array[0]只是Index特征的语法糖:编译器会将array[0]转换为array.index(0)调用,当然在调用之前,编译器会先检查array是否实现了Index特征。 - 接着,编译器检查
Rc<Box<[T; 3]>>是否有实现Index特征,结果是否,不仅如此,&Rc<Box<[T; 3]>>与&mut Rc<Box<[T; 3]>>也没有实现。 - 上面的都不能工作,编译器开始对
Rc<Box<[T; 3]>>进行解引用,把它转变成Box<[T; 3]> - 此时继续对
Box<[T; 3]>进行上面的操作 :Box<[T; 3]>,&Box<[T; 3]>,和&mut Box<[T; 3]>都没有实现Index特征,所以编译器开始对Box<[T; 3]>进行解引用,然后我们得到了[T; 3] [T; 3]以及它的各种引用都没有实现Index索引(是不是很反直觉:D,在直觉中,数组都可以通过索引访问,实际上只有数组切片才可以!),它也不能再进行解引用,因此编译器只能祭出最后的大杀器:将定长转为不定长,因此[T; 3]被转换成[T],也就是数组切片,它实现了Index特征,因此最终我们可以通过index方法访问到对应的元素。
过程看起来很复杂,但是也还好,挺好理解,如果你现在不能彻底理解,也不要紧,等以后对 Rust 理解更深了,同时需要深入理解类型转换时,再来细细品读本章。
再来看看以下更复杂的例子:
#![allow(unused)] fn main() { fn do_stuff<T: Clone>(value: &T) { let cloned = value.clone(); } }
上面例子中 cloned 的类型是什么?首先编译器检查能不能进行值方法调用, value 的类型是 &T,同时 clone 方法的签名也是 &T : fn clone(&T) -> T,因此可以进行值方法调用,再加上编译器知道了 T 实现了 Clone,因此 cloned 的类型是 T。
如果 T: Clone 的特征约束被移除呢?
#![allow(unused)] fn main() { fn do_stuff<T>(value: &T) { let cloned = value.clone(); } }
首先,从直觉上来说,该方法会报错,因为 T 没有实现 Clone 特征,但是真实情况是什么呢?
我们先来推导一番。 首先通过值方法调用就不再可行,因为 T 没有实现 Clone 特征,也就无法调用 T 的 clone 方法。接着编译器尝试引用方法调用,此时 T 变成 &T,在这种情况下, clone 方法的签名如下: fn clone(&&T) -> &T,接着我们现在对 value 进行了引用。 编译器发现 &T 实现了 Clone 类型(所有的引用类型都可以被复制,因为其实就是复制一份地址),因此可以推出 cloned 也是 &T 类型。
最终,我们复制出一份引用指针,这很合理,因为值类型 T 没有实现 Clone,只能去复制一个指针了。
下面的例子也是自动引用生效的地方:
#![allow(unused)] fn main() { #[derive(Clone)] struct Container<T>(Arc<T>); fn clone_containers<T>(foo: &Container<i32>, bar: &Container<T>) { let foo_cloned = foo.clone(); let bar_cloned = bar.clone(); } }
推断下上面的 foo_cloned 和 bar_cloned 是什么类型?提示: 关键在 Container 的泛型参数,一个是 i32 的具体类型,一个是泛型类型,其中 i32 实现了 Clone,但是 T 并没有。
首先要复习一下复杂类型派生 Clone 的规则:一个复杂类型能否派生 Clone,需要它内部的所有子类型都能进行 Clone。因此 Container<T>(Arc<T>) 是否实现 Clone 的关键在于 T 类型是否实现了 Clone 特征。
上面代码中,Container<i32> 实现了 Clone 特征,因此编译器可以直接进行值方法调用,此时相当于直接调用 foo.clone,其中 clone 的函数签名是 fn clone(&T) -> T,由此可以看出 foo_cloned 的类型是 Container<i32>。
然而,bar_cloned 的类型却是 &Container<T>,这个不合理啊,明明我们为 Container<T> 派生了 Clone 特征,因此它也应该是 Container<T> 类型才对。万事皆有因,我们先来看下 derive 宏最终生成的代码大概是啥样的:
#![allow(unused)] fn main() { impl<T> Clone for Container<T> where T: Clone { fn clone(&self) -> Self { Self(Arc::clone(&self.0)) } } }
从上面代码可以看出,派生 Clone 能实现的根本是 T 实现了Clone特征:where T: Clone, 因此 Container<T> 就没有实现 Clone 特征。
编译器接着会去尝试引用方法调用,此时 &Container<T> 引用实现了 Clone,最终可以得出 bar_cloned 的类型是 &Container<T>。
当然,也可以为 Container<T> 手动实现 Clone 特征:
#![allow(unused)] fn main() { impl<T> Clone for Container<T> { fn clone(&self) -> Self { Self(Arc::clone(&self.0)) } } }
此时,编译器首次尝试值方法调用即可通过,因此 bar_cloned 的类型变成 Container<T>。
变形记(Transmutes)
前方危险,敬请绕行!
类型系统,你让开!我要自己转换这些类型,不成功便成仁!虽然本书大多是关于安全的内容,我还是希望你能仔细考虑避免使用本章讲到的内容。这是你在 Rust 中所能做到的真真正正、彻彻底底、最最可怕的非安全行为,在这里,所有的保护机制都形同虚设。
先让你看看深渊长什么样,开开眼,然后你再决定是否深入: mem::transmute<T, U> 将类型 T 直接转成类型 U,唯一的要求就是,这两个类型占用同样大小的字节数!我的天,这也算限制?这简直就是无底线的转换好吧?看看会导致什么问题:
- 首先也是最重要的,转换后创建一个任意类型的实例会造成无法想象的混乱,而且根本无法预测。不要把
3转换成bool类型,就算你根本不会去使用该bool类型,也不要去这样转换 - 变形后会有一个重载的返回类型,即使你没有指定返回类型,为了满足类型推导的需求,依然会产生千奇百怪的类型
- 将
&变形为&mut是未定义的行为- 这种转换永远都是未定义的
- 不,你不能这么做
- 不要多想,你没有那种幸运
- 变形为一个未指定生命周期的引用会导致无界生命周期
- 在复合类型之间互相变换时,你需要保证它们的排列布局是一模一样的!一旦不一样,那么字段就会得到不可预期的值,这也是未定义的行为,至于你会不会因此愤怒, WHO CARES ,你都用了变形了,老兄!
对于第 5 条,你该如何知道内存的排列布局是一样的呢?对于 repr(C) 类型和 repr(transparent) 类型来说,它们的布局是有着精确定义的。但是对于你自己的"普通却自信"的 Rust 类型 repr(Rust) 来说,它可不是有着精确定义的。甚至同一个泛型类型的不同实例都可以有不同的内存布局。 Vec<i32> 和 Vec<u32> 它们的字段可能有着相同的顺序,也可能没有。对于数据排列布局来说,什么能保证,什么不能保证目前还在 Rust 开发组的工作任务中呢。
你以为你之前凝视的是深渊吗?不,你凝视的只是深渊的大门。 mem::transmute_copy<T, U> 才是真正的深渊,它比之前的还要更加危险和不安全。它从 T 类型中拷贝出 U 类型所需的字节数,然后转换成 U。 mem::transmute 尚有大小检查,能保证两个数据的内存大小一致,现在这哥们干脆连这个也丢了,只不过 U 的尺寸若是比 T 大,会是一个未定义行为。
当然,你也可以通过裸指针转换和 unions (todo!)获得所有的这些功能,但是你将无法获得任何编译提示或者检查。裸指针转换和 unions 也不是魔法,无法逃避上面说的规则。
transmute 虽然危险,但作为一本工具书,知识当然要全面,下面列举两个有用的 transmute 应用场景 :)。
- 将裸指针变成函数指针:
#![allow(unused)] fn main() { fn foo() -> i32 { 0 } let pointer = foo as *const (); let function = unsafe { // 将裸指针转换为函数指针 std::mem::transmute::<*const (), fn() -> i32>(pointer) }; assert_eq!(function(), 0); }
- 延长生命周期,或者缩短一个静态生命周期寿命:
#![allow(unused)] fn main() { struct R<'a>(&'a i32); // 将 'b 生命周期延长至 'static 生命周期 unsafe fn extend_lifetime<'b>(r: R<'b>) -> R<'static> { std::mem::transmute::<R<'b>, R<'static>>(r) } // 将 'static 生命周期缩短至 'c 生命周期 unsafe fn shorten_invariant_lifetime<'b, 'c>(r: &'b mut R<'static>) -> &'b mut R<'c> { std::mem::transmute::<&'b mut R<'static>, &'b mut R<'c>>(r) } }
以上例子非常先进!但是是非常不安全的 Rust 行为!
格式化输出
提到格式化输出,可能很多人立刻就想到 "{}",但是 Rust 能做到的远比这个多的多,本章节我们将深入讲解格式化输出的各个方面。
满分初印象
先来一段代码,看看格式化输出的初印象:
#![allow(unused)] fn main() { println!("Hello"); // => "Hello" println!("Hello, {}!", "world"); // => "Hello, world!" println!("The number is {}", 1); // => "The number is 1" println!("{:?}", (3, 4)); // => "(3, 4)" println!("{value}", value=4); // => "4" println!("{} {}", 1, 2); // => "1 2" println!("{:04}", 42); // => "0042" with leading zeros }
可以看到 println! 宏接受的是可变参数,第一个参数是一个字符串常量,它表示最终输出字符串的格式,包含其中形如 {} 的符号是占位符,会被 println! 后面的参数依次替换。
print!,println!,format!
它们是 Rust 中用来格式化输出的三大金刚,用途如下:
print!将格式化文本输出到标准输出,不带换行符println!同上,但是在行的末尾添加换行符format!将格式化文本输出到String字符串
在实际项目中,最常用的是 println! 及 format!,前者常用来调试输出,后者常用来生成格式化的字符串:
fn main() { let s = "hello"; println!("{}, world", s); let s1 = format!("{}, world", s); print!("{}", s1); print!("{}\n", "!"); }
其中,s1 是通过 format! 生成的 String 字符串,最终输出如下:
hello, world
hello, world!
eprint!,eprintln!
除了三大金刚外,还有两大护法,使用方式跟 print!,println! 很像,但是它们输出到标准错误输出:
#![allow(unused)] fn main() { eprintln!("Error: Could not complete task") }
它们仅应该被用于输出错误信息和进度信息,其它场景都应该使用 print! 系列。
{} 与 {:?}
与其它语言常用的 %d,%s 不同,Rust 特立独行地选择了 {} 作为格式化占位符(说到这个,有点想吐槽下,Rust 中自创的概念其实还挺多的,真不知道该夸奖还是该吐槽-,-),事实证明,这种选择非常正确,它帮助用户减少了很多使用成本,你无需再为特定的类型选择特定的占位符,统一用 {} 来替代即可,剩下的类型推导等细节只要交给 Rust 去做。
与 {} 类似,{:?} 也是占位符:
{}适用于实现了std::fmt::Display特征的类型,用来以更优雅、更友好的方式格式化文本,例如展示给用户{:?}适用于实现了std::fmt::Debug特征的类型,用于调试场景
其实两者的选择很简单,当你在写代码需要调试时,使用 {:?},剩下的场景,选择 {}。
Debug 特征
事实上,为了方便我们调试,大多数 Rust 类型都实现了 Debug 特征或者支持派生该特征:
#[derive(Debug)] struct Person { name: String, age: u8 } fn main() { let i = 3.1415926; let s = String::from("hello"); let v = vec![1, 2, 3]; let p = Person{name: "sunface".to_string(), age: 18}; println!("{:?}, {:?}, {:?}, {:?}", i, s, v, p); }
对于数值、字符串、数组,可以直接使用 {:?} 进行输出,但是对于结构体,需要派生Debug特征后,才能进行输出,总之很简单。
Display 特征
与大部分类型实现了 Debug 不同,实现了 Display 特征的 Rust 类型并没有那么多,往往需要我们自定义想要的格式化方式:
#![allow(unused)] fn main() { let i = 3.1415926; let s = String::from("hello"); let v = vec![1, 2, 3]; let p = Person { name: "sunface".to_string(), age: 18, }; println!("{}, {}, {}, {}", i, s, v, p); }
运行后可以看到 v 和 p 都无法通过编译,因为没有实现 Display 特征,但是你又不能像派生 Debug 一般派生 Display,只能另寻他法:
- 使用
{:?}或{:#?} - 为自定义类型实现
Display特征 - 使用
newtype为外部类型实现Display特征
下面来一一看看这三种方式。
{:#?}
{:#?} 与 {:?} 几乎一样,唯一的区别在于它能更优美地输出内容:
// {:?}
[1, 2, 3], Person { name: "sunface", age: 18 }
// {:#?}
[
1,
2,
3,
], Person {
name: "sunface",
}
因此对于 Display 不支持的类型,可以考虑使用 {:#?} 进行格式化,虽然理论上它更适合进行调试输出。
为自定义类型实现 Display 特征
如果你的类型是定义在当前作用域中的,那么可以为其实现 Display 特征,即可用于格式化输出:
struct Person { name: String, age: u8, } use std::fmt; impl fmt::Display for Person { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, "大佬在上,请受我一拜,小弟姓名{},年芳{},家里无田又无车,生活苦哈哈", self.name, self.age ) } } fn main() { let p = Person { name: "sunface".to_string(), age: 18, }; println!("{}", p); }
如上所示,只要实现 Display 特征中的 fmt 方法,即可为自定义结构体 Person 添加自定义输出:
大佬在上,请受我一拜,小弟姓名sunface,年芳18,家里无田又无车,生活苦哈哈
为外部类型实现 Display 特征
在 Rust 中,无法直接为外部类型实现外部特征,但是可以使用newtype解决此问题:
struct Array(Vec<i32>); use std::fmt; impl fmt::Display for Array { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "数组是:{:?}", self.0) } } fn main() { let arr = Array(vec![1, 2, 3]); println!("{}", arr); }
Array 就是我们的 newtype,它将想要格式化输出的 Vec 包裹在内,最后只要为 Array 实现 Display 特征,即可进行格式化输出:
数组是:[1, 2, 3]
至此,关于 {} 与 {:?} 的内容已介绍完毕,下面让我们正式开始格式化输出的旅程。
位置参数
除了按照依次顺序使用值去替换占位符之外,还能让指定位置的参数去替换某个占位符,例如 {1},表示用第二个参数替换该占位符(索引从 0 开始):
fn main() { println!("{}{}", 1, 2); // =>"12" println!("{1}{0}", 1, 2); // =>"21" // => Alice, this is Bob. Bob, this is Alice println!("{0}, this is {1}. {1}, this is {0}", "Alice", "Bob"); println!("{1}{}{0}{}", 1, 2); // => 2112 }
具名参数
除了像上面那样指定位置外,我们还可以为参数指定名称:
fn main() { println!("{argument}", argument = "test"); // => "test" println!("{name} {}", 1, name = 2); // => "2 1" println!("{a} {c} {b}", a = "a", b = 'b', c = 3); // => "a 3 b" }
需要注意的是:带名称的参数必须放在不带名称参数的后面,例如下面代码将报错:
#![allow(unused)] fn main() { println!("{abc} {1}", abc = "def", 2); }
#![allow(unused)] fn main() { error: positional arguments cannot follow named arguments --> src/main.rs:4:36 | 4 | println!("{abc} {1}", abc = "def", 2); | ----- ^ positional arguments must be before named arguments | | | named argument }
格式化参数
格式化输出,意味着对输出格式会有更多的要求,例如只输出浮点数的小数点后两位:
fn main() { let v = 3.1415926; // Display => 3.14 println!("{:.2}", v); // Debug => 3.14 println!("{:.2?}", v); }
上面代码只输出小数点后两位。同时我们还展示了 {} 和 {:?} 的用法,后面如无特殊区别,就只针对 {} 提供格式化参数说明。
接下来,让我们一起来看看 Rust 中有哪些格式化参数。
宽度
宽度用来指示输出目标的长度,如果长度不够,则进行填充和对齐:
字符串填充
字符串格式化默认使用空格进行填充,并且进行左对齐。
fn main() { //----------------------------------- // 以下全部输出 "Hello x !" // 为"x"后面填充空格,补齐宽度5 println!("Hello {:5}!", "x"); // 使用参数5来指定宽度 println!("Hello {:1$}!", "x", 5); // 使用x作为占位符输出内容,同时使用5作为宽度 println!("Hello {1:0$}!", 5, "x"); // 使用有名称的参数作为宽度 println!("Hello {:width$}!", "x", width = 5); //----------------------------------- // 使用参数5为参数x指定宽度,同时在结尾输出参数5 => Hello x !5 println!("Hello {:1$}!{}", "x", 5); }
数字填充:符号和 0
数字格式化默认也是使用空格进行填充,但与字符串左对齐不同的是,数字是右对齐。
fn main() { // 宽度是5 => Hello 5! println!("Hello {:5}!", 5); // 显式的输出正号 => Hello +5! println!("Hello {:+}!", 5); // 宽度5,使用0进行填充 => Hello 00005! println!("Hello {:05}!", 5); // 负号也要占用一位宽度 => Hello -0005! println!("Hello {:05}!", -5); }
对齐
fn main() { // 以下全部都会补齐5个字符的长度 // 左对齐 => Hello x ! println!("Hello {:<5}!", "x"); // 右对齐 => Hello x! println!("Hello {:>5}!", "x"); // 居中对齐 => Hello x ! println!("Hello {:^5}!", "x"); // 对齐并使用指定符号填充 => Hello x&&&&! // 指定符号填充的前提条件是必须有对齐字符 println!("Hello {:&<5}!", "x"); }
精度
精度可以用于控制浮点数的精度或者字符串的长度
fn main() { let v = 3.1415926; // 保留小数点后两位 => 3.14 println!("{:.2}", v); // 带符号保留小数点后两位 => +3.14 println!("{:+.2}", v); // 不带小数 => 3 println!("{:.0}", v); // 通过参数来设定精度 => 3.1416,相当于{:.4} println!("{:.1$}", v, 4); let s = "hi我是Sunface孙飞"; // 保留字符串前三个字符 => hi我 println!("{:.3}", s); // {:.*}接收两个参数,第一个是精度,第二个是被格式化的值 => Hello abc! println!("Hello {:.*}!", 3, "abcdefg"); }
进制
可以使用 # 号来控制数字的进制输出:
#b, 二进制#o, 八进制#x, 小写十六进制#X, 大写十六进制x, 不带前缀的小写十六进制
fn main() { // 二进制 => 0b11011! println!("{:#b}!", 27); // 八进制 => 0o33! println!("{:#o}!", 27); // 十进制 => 27! println!("{}!", 27); // 小写十六进制 => 0x1b! println!("{:#x}!", 27); // 大写十六进制 => 0x1B! println!("{:#X}!", 27); // 不带前缀的十六进制 => 1b! println!("{:x}!", 27); // 使用0填充二进制,宽度为10 => 0b00011011! println!("{:#010b}!", 27); }
指数
fn main() { println!("{:2e}", 1000000000); // => 1e9 println!("{:2E}", 1000000000); // => 1E9 }
指针地址
#![allow(unused)] fn main() { let v= vec![1, 2, 3]; println!("{:p}", v.as_ptr()) // => 0x600002324050 }
转义
有时需要输出 {和},但这两个字符是特殊字符,需要进行转义:
fn main() { // "{{" 转义为 '{' "}}" 转义为 '}' "\"" 转义为 '"' // => Hello "{World}" println!(" Hello \"{{World}}\" "); // 下面代码会报错,因为占位符{}只有一个右括号},左括号被转义成字符串的内容 // println!(" {{ Hello } "); // 也不可使用 '\' 来转义 "{}" // println!(" \{ Hello \} ") }
在格式化字符串时捕获环境中的值(Rust 1.58 新增)
在以前,想要输出一个函数的返回值,你需要这么做:
fn get_person() -> String { String::from("sunface") } fn main() { let p = get_person(); println!("Hello, {}!", p); // implicit position println!("Hello, {0}!", p); // explicit index println!("Hello, {person}!", person = p); }
问题倒也不大,但是一旦格式化字符串长了后,就会非常冗余,而在 1.58 后,我们可以这么写:
fn get_person() -> String { String::from("sunface") } fn main() { let person = get_person(); println!("Hello, {person}!"); }
是不是清晰、简洁了很多?甚至还可以将环境中的值用于格式化参数:
#![allow(unused)] fn main() { let (width, precision) = get_format(); for (name, score) in get_scores() { println!("{name}: {score:width$.precision$}"); } }
但也有局限,它只能捕获普通的变量,对于更复杂的类型(例如表达式),可以先将它赋值给一个变量或使用以前的 name = expression 形式的格式化参数。
目前除了 panic! 外,其它接收格式化参数的宏,都可以使用新的特性。对于 panic! 而言,如果还在使用 2015版本 或 2018版本,那 panic!("{ident}") 依然会被当成 正常的字符串来处理,同时编译器会给予 warn 提示。而对于 2021版本 ,则可以正常使用:
fn get_person() -> String { String::from("sunface") } fn main() { let person = get_person(); panic!("Hello, {person}!"); }
输出:
thread 'main' panicked at 'Hello, sunface!', src/main.rs:6:5
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
构建一个简单命令行程序
在前往更高的山峰前,我们应该驻足欣赏下身后的风景,虽然是半览众山不咋小,但总比身在此山中无法窥全貌要强一丢丢。
在本章中,我们将一起构建一个命令行程序,目标是尽可能帮大家融会贯通之前的学到的知识。
linux 系统中的 grep 命令很强大,可以完成各种文件搜索任务,我们肯定做不了那么强大,但是假冒一个伪劣的版本还是可以的,它将从命令行参数中读取指定的文件名和字符串,然后在相应的文件中找到包含该字符串的内容,最终打印出来。
这里推荐一位大神写的知名 Rust 项目 ripgrep ,绝对是
grep真正的高替品,值得学习和使用
实现基本功能
无论功能设计的再怎么花里胡哨,对于一个文件查找命令而言,首先得指定文件和待查找的字符串,它们需要用户从命令行给予输入,然后我们在程序内进行读取。
接收命令行参数
国际惯例,先创建一个新的项目 minigrep ,该名字充分体现了我们的自信:就是不如 grep。
cargo new minigrep
Created binary (application) `minigrep` project
$ cd minigrep
首先来思考下,如果要传入文件路径和待搜索的字符串,那这个命令该长啥样,我觉得大概率是这样:
cargo run -- searchstring example-filename.txt
-- 告诉 cargo 后面的参数是给我们的程序使用的,而不是给 cargo 自己使用,例如 -- 前的 run 就是给它用的。
接下来就是在程序中读取传入的参数,这个很简单,下面代码就可以:
// in main.rs use std::env; fn main() { let args: Vec<String> = env::args().collect(); dbg!(args); }
首先通过 use 引入标准库中的 env 包,然后 env::args 方法会读取并分析传入的命令行参数,最终通过 collect 方法输出一个集合类型 Vector。
可能有同学疑惑,为啥不直接引入 args ,例如 use std::env::args ,这样就无需 env::args 来繁琐调用,直接args.collect() 即可。原因很简单,args 方法只会使用一次,啰嗦就啰嗦点吧,把相同的好名字让给 let args.. 这位大哥不好吗?毕竟人家要出场多次的。
不可信的输入
所有的用户输入都不可信!不可信!不可信!
重要的话说三遍,我们的命令行程序也是,用户会输入什么你根本就不知道,例如他输入了一个非 Unicode 字符,你能阻止吗?显然不能,但是这种输入会直接让我们的程序崩溃!
原因是当传入的命令行参数包含非 Unicode 字符时,
std::env::args会直接崩溃,如果有这种特殊需求,建议大家使用std::env::args_os,该方法产生的数组将包含OsString类型,而不是之前的String类型,前者对于非 Unicode 字符会有更好的处理。至于为啥我们不用,两个理由,你信哪个:1. 用户爱输入啥输入啥,反正崩溃了,他就知道自己错了 2.
args_os会引入额外的跨平台复杂性
collect 方法其实并不是std::env包提供的,而是迭代器自带的方法(env::args() 会返回一个迭代器),它会将迭代器消费后转换成我们想要的集合类型,关于迭代器和 collect 的具体介绍,请参加这里。
最后,代码中使用 dbg! 宏来输出读取到的数组内容,来看看长啥样:
$ cargo run
Compiling minigrep v0.1.0 (file:///projects/minigrep)
Finished dev [unoptimized + debuginfo] target(s) in 0.61s
Running `target/debug/minigrep`
[src/main.rs:5] args = [
"target/debug/minigrep",
]
$ cargo run -- needle haystack
Compiling minigrep v0.1.0 (file:///projects/minigrep)
Finished dev [unoptimized + debuginfo] target(s) in 1.57s
Running `target/debug/minigrep needle haystack`
[src/main.rs:5] args = [
"target/debug/minigrep",
"needle",
"haystack",
]
上面两个版本分别是无参数和两个参数,其中无参数版本实际上也会读取到一个字符串,仔细看,是不是长得很像我们的程序名,Bingo! env::args 读取到的参数中第一个就是程序的可执行路径名。
存储读取到的参数
在编程中,给予清晰合理的变量名是一项基本功,咱总不能到处都是 args[1] 、args[2] 这样的糟糕代码吧。
因此我们需要两个变量来存储文件路径和待搜索的字符串:
use std::env; fn main() { let args: Vec<String> = env::args().collect(); let query = &args[1]; let file_path = &args[2]; println!("Searching for {}", query); println!("In file {}", file_path); }
很简单的代码,来运行下:
$ cargo run -- test sample.txt
Compiling minigrep v0.1.0 (file:///projects/minigrep)
Finished dev [unoptimized + debuginfo] target(s) in 0.0s
Running `target/debug/minigrep test sample.txt`
Searching for test
In file sample.txt
输出结果很清晰的说明了我们的目标:在文件 sample.txt 中搜索包含 test 字符串的内容。
事实上,就算作为一个简单的程序,它也太过于简单了,例如用户不提供任何参数怎么办?因此,错误处理显然是不可少的,但是在添加之前,先来看看如何读取文件内容。
文件读取
既然读取文件,那么首先我们需要创建一个文件并给予一些内容,来首诗歌如何?"我啥也不是,你呢?"
I'm nobody! Who are you?
我啥也不是,你呢?
Are you nobody, too?
牛逼如你也是无名之辈吗?
Then there's a pair of us - don't tell!
那我们就是天生一对,嘘!别说话!
They'd banish us, you know.
你知道,我们不属于这里。
How dreary to be somebody!
因为这里属于没劲的大人物!
How public, like a frog
他们就像青蛙一样呱噪,
To tell your name the livelong day
成天将自己的大名
To an admiring bog!
传遍整个无聊的沼泽!
在项目根目录创建 poem.txt 文件,并写入如上的优美诗歌(可能翻译的很烂,别打我,哈哈,事实上大家写入英文内容就够了)。
接下来修改 main.rs 来读取文件内容:
use std::env; use std::fs; fn main() { // --省略之前的内容-- println!("In file {}", file_path); let contents = fs::read_to_string(file_path) .expect("Should have been able to read the file"); println!("With text:\n{contents}"); }
首先,通过 use std::fs 引入文件操作包,然后通过 fs::read_to_string 读取指定的文件内容,最后返回的 contents 是 std::io::Result<String> 类型。
运行下试试,这里无需输入第二个参数,因为我们还没有实现查询功能:
$ cargo run -- the poem.txt
Compiling minigrep v0.1.0 (file:///projects/minigrep)
Finished dev [unoptimized + debuginfo] target(s) in 0.0s
Running `target/debug/minigrep the poem.txt`
Searching for the
In file poem.txt
With text:
I'm nobody! Who are you?
Are you nobody, too?
Then there's a pair of us - don't tell!
They'd banish us, you know.
How dreary to be somebody!
How public, like a frog
To tell your name the livelong day
To an admiring bog!
完美,虽然代码还有很多瑕疵,例如所有内容都在 main 函数,这个不符合软件工程,没有错误处理,功能不完善等。不过没关系,万事开头难,好歹我们成功迈开了第一步。
好了,是时候重构赚波 KPI 了,读者:are you serious? 这就开始重构了?
增加模块化和错误处理
但凡稍微没那么糟糕的程序,都应该具有代码模块化和错误处理,不然连玩具都谈不上。
梳理我们的代码和目标后,可以整理出大致四个改进点:
- 单一且庞大的函数。对于
minigrep程序而言,main函数当前执行两个任务:解析命令行参数和读取文件。但随着代码的增加,main函数承载的功能也将快速增加。从软件工程角度来看,一个函数具有的功能越多,越是难以阅读和维护。因此最好的办法是将大的函数拆分成更小的功能单元。 - 配置变量散乱在各处。还有一点要考虑的是,当前
main函数中的变量都是独立存在的,这些变量很可能被整个程序所访问,在这个背景下,独立的变量越多,越是难以维护,因此我们还可以将这些用于配置的变量整合到一个结构体中。 - 细化错误提示。 目前的实现中,我们使用
expect方法来输出文件读取失败时的错误信息,这个没问题,但是无论任何情况下,都只输出Should have been able to read the file这条错误提示信息,显然是有问题的,毕竟文件不存在、无权限等等都是可能的错误,一条大一统的消息无法给予用户更多的提示。 - 使用错误而不是异常。 假如用户不给任何命令行参数,那我们的程序显然会无情崩溃,原因很简单:
index out of bounds,一个数组访问越界的panic,但问题来了,用户能看懂吗?甚至于未来接收的维护者能看懂吗?因此需要增加合适的错误处理代码,来给予使用者给详细友善的提示。还有就是需要在一个统一的位置来处理所有错误,利人利己!
分离 main 函数
关于如何处理庞大的 main 函数,Rust 社区给出了统一的指导方案:
- 将程序分割为
main.rs和lib.rs,并将程序的逻辑代码移动到后者内 - 命令行解析属于非常基础的功能,严格来说不算是逻辑代码的一部分,因此还可以放在
main.rs中
按照这个方案,将我们的代码重新梳理后,可以得出 main 函数应该包含的功能:
- 解析命令行参数
- 初始化其它配置
- 调用
lib.rs中的run函数,以启动逻辑代码的运行 - 如果
run返回一个错误,需要对该错误进行处理
这个方案有一个很优雅的名字: 关注点分离(Separation of Concerns)。简而言之,main.rs 负责启动程序,lib.rs 负责逻辑代码的运行。从测试的角度而言,这种分离也非常合理: lib.rs 中的主体逻辑代码可以得到简单且充分的测试,至于 main.rs ?确实没办法针对其编写额外的测试代码,但是它的代码也很少啊,很容易就能保证它的正确性。
关于如何在 Rust 中编写测试代码,请参见如下章节:https://course.rs/test/intro.html
分离命令行解析
根据之前的分析,我们需要将命令行解析的代码分离到一个单独的函数,然后将该函数放置在 main.rs 中:
// in main.rs fn main() { let args: Vec<String> = env::args().collect(); let (query, file_path) = parse_config(&args); // --省略-- } fn parse_config(args: &[String]) -> (&str, &str) { let query = &args[1]; let file_path = &args[2]; (query, file_path) }
经过分离后,之前的设计目标完美达成,即精简了 main 函数,又将配置相关的代码放在了 main.rs 文件里。
看起来貌似是杀鸡用了牛刀,但是重构就是这样,一步一步,踏踏实实的前行,否则未来代码多一些后,你岂不是还要再重来一次重构?因此打好项目的基础是非常重要的!
聚合配置变量
前文提到,配置变量并不适合分散的到处都是,因此使用一个结构体来统一存放是非常好的选择,这样修改后,后续的使用以及未来的代码维护都将更加简单明了。
fn main() { let args: Vec<String> = env::args().collect(); let config = parse_config(&args); println!("Searching for {}", config.query); println!("In file {}", config.file_path); let contents = fs::read_to_string(config.file_path) .expect("Should have been able to read the file"); // --snip-- } struct Config { query: String, file_path: String, } fn parse_config(args: &[String]) -> Config { let query = args[1].clone(); let file_path = args[2].clone(); Config { query, file_path } }
值得注意的是,Config 中存储的并不是 &str 这样的引用类型,而是一个 String 字符串,也就是 Config 并没有去借用外部的字符串,而是拥有内部字符串的所有权。clone 方法的使用也可以佐证这一点。大家可以尝试不用 clone 方法,看看该如何解决相关的报错 :D
clone的得与失在上面的代码中,除了使用
clone,还有其它办法来达成同样的目的,但clone无疑是最简单的方法:直接完整的复制目标数据,无需被所有权、借用等问题所困扰,但是它也有其缺点,那就是有一定的性能损耗。因此是否使用
clone更多是一种性能上的权衡,对于上面的使用而言,由于是配置的初始化,因此整个程序只需要执行一次,性能损耗几乎是可以忽略不计的。总之,判断是否使用
clone:
- 是否严肃的项目,玩具项目直接用
clone就行,简单不好吗?- 要看所在的代码路径是否是热点路径(hot path),例如执行次数较多的显然就是热点路径,热点路径就值得去使用性能更好的实现方式
好了,言归正传,从 C 语言过来的同学可能会觉得上面的代码已经很棒了,但是从 OO 语言角度来说,还差了那么一点意思。
下面我们试着来优化下,通过构造函数来初始化一个 Config 实例,而不是直接通过函数返回实例,典型的,标准库中的 String::new 函数就是一个范例。
fn main() { let args: Vec<String> = env::args().collect(); let config = Config::new(&args); // --snip-- } // --snip-- impl Config { fn new(args: &[String]) -> Config { let query = args[1].clone(); let file_path = args[2].clone(); Config { query, file_path } } }
修改后,类似 String::new 的调用,我们可以通过 Config::new 来创建一个实例,看起来代码是不是更有那味儿了 :)
错误处理
回顾一下,如果用户不输入任何命令行参数,我们的程序会怎么样?
$ cargo run
Compiling minigrep v0.1.0 (file:///projects/minigrep)
Finished dev [unoptimized + debuginfo] target(s) in 0.0s
Running `target/debug/minigrep`
thread 'main' panicked at 'index out of bounds: the len is 1 but the index is 1', src/main.rs:27:21
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
结果喜闻乐见,由于 args 数组没有任何元素,因此通过索引访问时,会直接报出数组访问越界的 panic。
报错信息对于开发者会很明确,但是对于使用者而言,就相当难理解了,下面一起来解决它。
改进报错信息
还记得在错误处理章节,我们提到过 panic 的两种用法: 被动触发和主动调用嘛?上面代码的出现方式很明显是被动触发,这种报错信息是不可控的,下面我们先改成主动调用的方式:
#![allow(unused)] fn main() { // in main.rs // --snip-- fn new(args: &[String]) -> Config { if args.len() < 3 { panic!("not enough arguments"); } // --snip-- }
目的很明确,一旦传入的参数数组长度小于 3,则报错并让程序崩溃推出,这样后续的数组访问就不会再越界了。
$ cargo run
Compiling minigrep v0.1.0 (file:///projects/minigrep)
Finished dev [unoptimized + debuginfo] target(s) in 0.0s
Running `target/debug/minigrep`
thread 'main' panicked at 'not enough arguments', src/main.rs:26:13
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
不错,用户看到了更为明确的提示,但是还是有一大堆 debug 输出,这些我们其实是不想让用户看到的。这么看来,想要输出对用户友好的信息, panic 是不太适合的,它更适合告知开发者,哪里出现了问题。
返回 Result 来替代直接 panic
那只能祭出之前学过的错误处理大法了,也就是返回一个 Result:成功时包含 Config 实例,失败时包含一条错误信息。
有一点需要额外注意下,从代码惯例的角度出发,new 往往不会失败,毕竟新建一个实例没道理失败,对不?因此修改为 build 会更加合适。
#![allow(unused)] fn main() { impl Config { fn build(args: &[String]) -> Result<Config, &'static str> { if args.len() < 3 { return Err("not enough arguments"); } let query = args[1].clone(); let file_path = args[2].clone(); Ok(Config { query, file_path }) } } }
这里的 Result 可能包含一个 Config 实例,也可能包含一条错误信息 &static str,不熟悉这种字符串类型的同学可以回头看看字符串章节,代码中的字符串字面量都是该类型,且拥有 'static 生命周期。
处理返回的 Result
接下来就是在调用 build 函数时,对返回的 Result 进行处理了,目的就是给出准确且友好的报错提示, 为了让大家更好的回顾我们修改过的内容,这里给出整体代码:
use std::env; use std::fs; use std::process; fn main() { let args: Vec<String> = env::args().collect(); // 对 build 返回的 `Result` 进行处理 let config = Config::build(&args).unwrap_or_else(|err| { println!("Problem parsing arguments: {err}"); process::exit(1); }); println!("Searching for {}", config.query); println!("In file {}", config.file_path); let contents = fs::read_to_string(config.file_path) .expect("Should have been able to read the file"); println!("With text:\n{contents}"); } struct Config { query: String, file_path: String, } impl Config { fn build(args: &[String]) -> Result<Config, &'static str> { if args.len() < 3 { return Err("not enough arguments"); } let query = args[1].clone(); let file_path = args[2].clone(); Ok(Config { query, file_path }) } }
上面代码有几点值得注意:
- 当
Result包含错误时,我们不再调用panic让程序崩溃,而是通过process::exit(1)来终结进程,其中1是一个信号值(事实上非 0 值都可以),通知调用我们程序的进程,程序是因为错误而退出的。 unwrap_or_else是定义在Result<T,E>上的常用方法,如果Result是Ok,那该方法就类似unwrap:返回Ok内部的值;如果是Err,就调用闭包中的自定义代码对错误进行进一步处理
综上可知,config 变量的值是一个 Config 实例,而 unwrap_or_else 闭包中的 err 参数,它的类型是 'static str,值是 "not enough arguments" 那个字符串字面量。
运行后,可以看到以下输出:
$ cargo run
Compiling minigrep v0.1.0 (file:///projects/minigrep)
Finished dev [unoptimized + debuginfo] target(s) in 0.48s
Running `target/debug/minigrep`
Problem parsing arguments: not enough arguments
终于,我们得到了自己想要的输出:既告知了用户为何报错,又消除了多余的 debug 信息,非常棒。可能有用户疑惑,cargo run 底下还有一大堆 debug 信息呢,实际上,这是 cargo run 自带的,大家可以试试编译成二进制可执行文件后再调用,会是什么效果。
分离主体逻辑
接下来可以继续精简 main 函数,那就是将主体逻辑( 例如业务逻辑 )从 main 中分离出去,这样 main 函数就保留主流程调用,非常简洁。
// in main.rs fn main() { let args: Vec<String> = env::args().collect(); let config = Config::build(&args).unwrap_or_else(|err| { println!("Problem parsing arguments: {err}"); process::exit(1); }); println!("Searching for {}", config.query); println!("In file {}", config.file_path); run(config); } fn run(config: Config) { let contents = fs::read_to_string(config.file_path) .expect("Should have been able to read the file"); println!("With text:\n{contents}"); } // --snip--
如上所示,main 函数仅保留主流程各个环节的调用,一眼看过去非常简洁清晰。
继续之前,先请大家仔细看看 run 函数,你们觉得还缺少什么?提示:参考 build 函数的改进过程。
使用 ? 和特征对象来返回错误
答案就是 run 函数没有错误处理,因为在文章开头我们提到过,错误处理最好统一在一个地方完成,这样极其有利于后续的代码维护。
#![allow(unused)] fn main() { //in main.rs use std::error::Error; // --snip-- fn run(config: Config) -> Result<(), Box<dyn Error>> { let contents = fs::read_to_string(config.file_path)?; println!("With text:\n{contents}"); Ok(()) } }
值得注意的是这里的 Result<(), Box<dyn Error>> 返回类型,首先我们的程序无需返回任何值,但是为了满足 Result<T,E> 的要求,因此使用了 Ok(()) 返回一个单元类型 ()。
最重要的是 Box<dyn Error>, 如果按照顺序学到这里,大家应该知道这是一个Error 的特征对象(为了使用 Error,我们通过 use std::error::Error; 进行了引入),它表示函数返回一个类型,该类型实现了 Error 特征,这样我们就无需指定具体的错误类型,否则你还需要查看 fs::read_to_string 返回的错误类型,然后复制到我们的 run 函数返回中,这么做一个是麻烦,最主要的是,一旦这么做,意味着我们无法在上层调用时统一处理错误,但是 Box<dyn Error> 不同,其它函数也可以返回这个特征对象,然后调用者就可以使用统一的方式来处理不同函数返回的 Box<dyn Error>。
明白了 Box<dyn Error> 的重要战略地位,接下来大家分析下,fs::read_to_string 返回的具体错误类型是怎么被转化为 Box<dyn Error> 的?其实原因在之前章节都有讲过,这里就不直接给出答案了,参见 ?-传播界的大明星。
运行代码看看效果:
$ cargo run the poem.txt
Compiling minigrep v0.1.0 (file:///projects/minigrep)
warning: unused `Result` that must be used
--> src/main.rs:19:5
|
19 | run(config);
| ^^^^^^^^^^^^
|
= note: `#[warn(unused_must_use)]` on by default
= note: this `Result` may be an `Err` variant, which should be handled
warning: `minigrep` (bin "minigrep") generated 1 warning
Finished dev [unoptimized + debuginfo] target(s) in 0.71s
Running `target/debug/minigrep the poem.txt`
Searching for the
In file poem.txt
With text:
I'm nobody! Who are you?
Are you nobody, too?
Then there's a pair of us - don't tell!
They'd banish us, you know.
How dreary to be somebody!
How public, like a frog
To tell your name the livelong day
To an admiring bog!
没任何问题,不过 Rust 编译器也给出了善意的提示,那就是 Result 并没有被使用,这可能意味着存在错误的潜在可能性。
处理返回的错误
fn main() { // --snip-- println!("Searching for {}", config.query); println!("In file {}", config.file_path); if let Err(e) = run(config) { println!("Application error: {e}"); process::exit(1); } }
先回忆下在 build 函数调用时,我们怎么处理错误的?然后与这里的方式做一下对比,是不是发现了一些区别?
没错 if let 的使用让代码变得更简洁,可读性也更加好,原因是,我们并不关注 run 返回的 Ok 值,因此只需要用 if let 去匹配是否存在错误即可。
好了,截止目前,代码看起来越来越美好了,距离我们的目标也只差一个:将主体逻辑代码分离到一个独立的文件 lib.rs 中。
分离逻辑代码到库包中
对于 Rust 的代码组织( 包和模块 )还不熟悉的同学,强烈建议回头温习下这一章。
首先,创建一个 src/lib.rs 文件,然后将所有的非 main 函数都移动到其中。代码大概类似:
#![allow(unused)] fn main() { use std::error::Error; use std::fs; pub struct Config { pub query: String, pub file_path: String, } impl Config { pub fn build(args: &[String]) -> Result<Config, &'static str> { // --snip-- } } pub fn run(config: Config) -> Result<(), Box<dyn Error>> { // --snip-- } }
为了内容的简洁性,这里忽略了具体的实现,下一步就是在 main.rs 中引入 lib.rs 中定义的 Config 类型。
use std::env; use std::process; use minigrep::Config; fn main() { // --snip-- let args: Vec<String> = env::args().collect(); let config = Config::build(&args).unwrap_or_else(|err| { println!("Problem parsing arguments: {err}"); process::exit(1); }); println!("Searching for {}", config.query); println!("In file {}", config.file_path); if let Err(e) = minigrep::run(config) { // --snip-- println!("Application error: {e}"); process::exit(1); } }
很明显,这里的 mingrep::run 的调用,以及 Config 的引入,跟使用其它第三方包已经没有任何区别,也意味着我们成功的将逻辑代码放置到一个独立的库包中,其它包只要引入和调用就行。
呼,一顿书写猛如虎,回头一看。。。这么长的篇幅就写了这么点简单的代码??只能说,我也希望像很多国内的大学教材一样,只要列出定理和解题方法,然后留下足够的习题,就万事大吉了,但是咱们不行。
接下来,到了最喜(令)闻(人)乐(讨)见(厌)的环节:写测试代码,一起来开心吧。
测试驱动开发
开始之前,推荐大家先了解下如何在 Rust 中编写测试代码,这块儿内容不复杂,先了解下有利于本章的继续阅读
在之前的章节中,我们完成了对项目结构的重构,并将进入逻辑代码编程的环节,但在此之前,我们需要先编写一些测试代码,也是最近颇为流行的测试驱动开发模式(TDD, Test Driven Development):
- 编写一个注定失败的测试,并且失败的原因和你指定的一样
- 编写一个成功的测试
- 编写你的逻辑代码,直到通过测试
这三个步骤将在我们的开发过程中不断循环,知道所有的代码都开发完成并成功通过所有测试。
注定失败的测试用例
既然要添加测试,那之前的 println! 语句将没有大的用处,毕竟 println! 存在的目的就是为了让我们看到结果是否正确,而现在测试用例将取而代之。
接下来,在 lib.rs 文件中,添加 tests 模块和 test 函数:
#![allow(unused)] fn main() { #[cfg(test)] mod tests { use super::*; #[test] fn one_result() { let query = "duct"; let contents = "\ Rust: safe, fast, productive. Pick three."; assert_eq!(vec!["safe, fast, productive."], search(query, contents)); } } }
测试用例将在指定的内容中搜索 duct 字符串,目测可得:其中有一行内容是包含有目标字符串的。
但目前为止,还无法运行该测试用例,更何况还想幸灾乐祸的看其失败,原因是 search 函数还没有实现!毕竟是测试驱动、测试先行。
#![allow(unused)] fn main() { // in lib.rs pub fn search<'a>(query: &str, contents: &'a str) -> Vec<&'a str> { vec![] } }
先添加一个简单的 search 函数实现,非常简单粗暴的返回一个空的数组,显而易见测试用例将成功通过,真是一个居心叵测的测试用例!
注意这里生命周期 'a 的使用,之前的章节有详细介绍,不太明白的同学可以回头看看。
喔,这么复杂的代码,都用上生命周期了!嘚瑟两下试试:
$ cargo test
Compiling minigrep v0.1.0 (file:///projects/minigrep)
Finished test [unoptimized + debuginfo] target(s) in 0.97s
Running unittests src/lib.rs (target/debug/deps/minigrep-9cd200e5fac0fc94)
running 1 test
test tests::one_result ... FAILED
failures:
---- tests::one_result stdout ----
thread 'main' panicked at 'assertion failed: `(left == right)`
left: `["safe, fast, productive."]`,
right: `[]`', src/lib.rs:44:9
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
failures:
tests::one_result
test result: FAILED. 0 passed; 1 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s
error: test failed, to rerun pass `--lib`
太棒了!它失败了...
务必成功的测试用例
接着,改进型测试驱动的第二步了:编写注定成功的测试。当然,前提条件是实现我们的 search 函数。它包含以下步骤:
- 遍历迭代
contents的每一行 - 检查该行内容是否包含我们的目标字符串
- 若包含,则放入返回值列表中,否则忽略
- 返回匹配到的返回值列表
遍历迭代每一行
Rust 提供了一个很便利的 lines 方法将目标字符串进行按行分割:
#![allow(unused)] fn main() { // in lib.rs pub fn search<'a>(query: &str, contents: &'a str) -> Vec<&'a str> { for line in contents.lines() { // do something with line } } }
这里的 lines 返回一个迭代器,关于迭代器在后续章节会详细讲解,现在只要知道 for 可以遍历取出迭代器中的值即可。
在每一行中查询目标字符串
#![allow(unused)] fn main() { // in lib.rs pub fn search<'a>(query: &str, contents: &'a str) -> Vec<&'a str> { for line in contents.lines() { if line.contains(query) { // do something with line } } } }
与之前的 lines 函数类似,Rust 的字符串还提供了 contains 方法,用于检查 line 是否包含待查询的 query。
接下来,只要返回合适的值,就可以完成 search 函数的编写。
存储匹配到的结果
简单,创建一个 Vec 动态数组,然后将查询到的每一个 line 推进数组中即可:
#![allow(unused)] fn main() { // in lib.rs pub fn search<'a>(query: &str, contents: &'a str) -> Vec<&'a str> { let mut results = Vec::new(); for line in contents.lines() { if line.contains(query) { results.push(line); } } results } }
至此,search 函数已经完成了既定目标,为了检查功能是否正确,运行下我们之前编写的测试用例:
$ cargo test
Compiling minigrep v0.1.0 (file:///projects/minigrep)
Finished test [unoptimized + debuginfo] target(s) in 1.22s
Running unittests src/lib.rs (target/debug/deps/minigrep-9cd200e5fac0fc94)
running 1 test
test tests::one_result ... ok
test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s
Running unittests src/main.rs (target/debug/deps/minigrep-9cd200e5fac0fc94)
running 0 tests
test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s
Doc-tests minigrep
running 0 tests
test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s
测试通过,意味着我们的代码也完美运行,接下来就是在 run 函数中大显身手了。
在 run 函数中调用 search 函数
#![allow(unused)] fn main() { // in src/lib.rs pub fn run(config: Config) -> Result<(), Box<dyn Error>> { let contents = fs::read_to_string(config.file_path)?; for line in search(&config.query, &contents) { println!("{line}"); } Ok(()) } }
好,再运行下看看结果,看起来我们距离成功从未如此之近!
$ cargo run -- frog poem.txt
Compiling minigrep v0.1.0 (file:///projects/minigrep)
Finished dev [unoptimized + debuginfo] target(s) in 0.38s
Running `target/debug/minigrep frog poem.txt`
How public, like a frog
酷!成功查询到包含 frog 的行,再来试试 body :
$ cargo run -- body poem.txt
Compiling minigrep v0.1.0 (file:///projects/minigrep)
Finished dev [unoptimized + debuginfo] target(s) in 0.0s
Running `target/debug/minigrep body poem.txt`
I'm nobody! Who are you?
Are you nobody, too?
How dreary to be somebody!
完美,三行,一行不少,为了确保万无一失,再来试试查询一个不存在的单词:
cargo run -- monomorphization poem.txt
Compiling minigrep v0.1.0 (file:///projects/minigrep)
Finished dev [unoptimized + debuginfo] target(s) in 0.0s
Running `target/debug/minigrep monomorphization poem.txt`
至此,章节开头的目标已经全部完成,接下来思考一个小问题:如果要为程序加上大小写不敏感的控制命令,由用户进行输入,该怎么实现比较好呢?毕竟在实际搜索查询中,同时支持大小写敏感和不敏感还是很重要的。
答案留待下一章节揭晓。
使用环境变量来增强程序
在上一章节中,留下了一个悬念,该如何实现用户控制的大小写敏感,其实答案很简单,你在其它程序中肯定也遇到过不少,例如如何控制 panic 后的栈展开? Rust 提供的解决方案是通过命令行参数来控制:
RUST_BACKTRACE=1 cargo run
与之类似,我们也可以使用环境变量来控制大小写敏感,例如:
IGNORE_CASE=1 cargo run -- to poem.txt
既然有了目标,那么一起来看看该如何实现吧。
编写大小写不敏感的测试用例
还是遵循之前的规则:测试驱动,这次是对一个新的大小写不敏感函数进行测试 search_case_insensitive。
还记得 TDD 的测试步骤嘛?首先编写一个注定失败的用例:
#![allow(unused)] fn main() { // in src/lib.rs #[cfg(test)] mod tests { use super::*; #[test] fn case_sensitive() { let query = "duct"; let contents = "\ Rust: safe, fast, productive. Pick three. Duct tape."; assert_eq!(vec!["safe, fast, productive."], search(query, contents)); } #[test] fn case_insensitive() { let query = "rUsT"; let contents = "\ Rust: safe, fast, productive. Pick three. Trust me."; assert_eq!( vec!["Rust:", "Trust me."], search_case_insensitive(query, contents) ); } } }
可以看到,这里新增了一个 case_insensitive 测试用例,并对 search_case_insensitive 进行了测试,结果显而易见,函数都没有实现,自然会失败。
接着来实现这个大小写不敏感的搜索函数:
#![allow(unused)] fn main() { pub fn search_case_insensitive<'a>( query: &str, contents: &'a str, ) -> Vec<&'a str> { let query = query.to_lowercase(); let mut results = Vec::new(); for line in contents.lines() { if line.to_lowercase().contains(&query) { results.push(line); } } results } }
跟之前一样,但是引入了一个新的方法 to_lowercase,它会将 line 转换成全小写的字符串,类似的方法在其它语言中也差不多,就不再赘述。
还要注意的是 query 现在是 String 类型,而不是之前的 &str,因为 to_lowercase 返回的是 String。
修改后,再来跑一次测试,看能否通过。
$ cargo test
Compiling minigrep v0.1.0 (file:///projects/minigrep)
Finished test [unoptimized + debuginfo] target(s) in 1.33s
Running unittests src/lib.rs (target/debug/deps/minigrep-9cd200e5fac0fc94)
running 2 tests
test tests::case_insensitive ... ok
test tests::case_sensitive ... ok
test result: ok. 2 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s
Running unittests src/main.rs (target/debug/deps/minigrep-9cd200e5fac0fc94)
running 0 tests
test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s
Doc-tests minigrep
running 0 tests
test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s
Ok,TDD的第二步也完成了,测试通过,接下来就是最后一步,在 run 中调用新的搜索函数。但是在此之前,要新增一个配置项,用于控制是否开启大小写敏感。
#![allow(unused)] fn main() { // in lib.rs pub struct Config { pub query: String, pub file_path: String, pub ignore_case: bool, } }
接下来就是检查该字段,来判断是否启动大小写敏感:
#![allow(unused)] fn main() { pub fn run(config: Config) -> Result<(), Box<dyn Error>> { let contents = fs::read_to_string(config.file_path)?; let results = if config.ignore_case { search_case_insensitive(&config.query, &contents) } else { search(&config.query, &contents) }; for line in results { println!("{line}"); } Ok(()) } }
现在的问题来了,该如何控制这个配置项呢。这个就要借助于章节开头提到的环境变量,好在 Rust 的 env 包提供了相应的方法。
#![allow(unused)] fn main() { use std::env; // --snip-- impl Config { pub fn build(args: &[String]) -> Result<Config, &'static str> { if args.len() < 3 { return Err("not enough arguments"); } let query = args[1].clone(); let file_path = args[2].clone(); let ignore_case = env::var("IGNORE_CASE").is_ok(); Ok(Config { query, file_path, ignore_case, }) } } }
env::var 没啥好说的,倒是 is_ok 值得说道下。该方法是 Result 提供的,用于检查是否有值,有就返回 true,没有则返回 false,刚好完美符合我们的使用场景,因为我们并不关心 Ok<T> 中具体的值。
运行下试试:
$ cargo run -- to poem.txt
Compiling minigrep v0.1.0 (file:///projects/minigrep)
Finished dev [unoptimized + debuginfo] target(s) in 0.0s
Running `target/debug/minigrep to poem.txt`
Are you nobody, too?
How dreary to be somebody!
看起来没有问题,接下来测试下大小写不敏感:
$ IGNORE_CASE=1 cargo run -- to poem.txt
Are you nobody, too?
How dreary to be somebody!
To tell your name the livelong day
To an admiring bog!
大小写不敏感后,查询到的内容明显多了很多,也很符合我们的预期。
最后,给大家留一个小作业:同时使用命令行参数和环境变量的方式来控制大小写不敏感,其中环境变量的优先级更高,也就是两个都设置的情况下,优先使用环境变量的设置。
使用迭代器来改进我们的程序
本章节是可选内容,请大家在看完迭代器章节后,再来阅读
Rust 高级进阶
恭喜你,学会 Rust 基础后,金丹大道已在向你招手,大部分 Rust 代码对你来说都是家常便饭,简单得很。可是,对于一门难度传言在外的语言,怎么可能如此简单的就被征服,最难的生命周期,咱还没见过长啥样呢。
从本章开始,我们将进入 Rust 的进阶学习环节,与基础环节不同的是,由于你已经对 Rust 有了一定的认识,因此我们不会再对很多细节进行翻来覆去的详细讲解,甚至会一带而过。
总之,欢迎来到高级 Rust 的世界,全新的 Boss,全新的装备,你准备好了吗?
生命周期
何为高阶?一个字:难,二个字:很难,七个字:其实也没那么难。至于到底难不难,还是交给各位看官评判吧 :D
大家都知道,生命周期在 Rust 中是最难的部分之一,,因此相关内容被分成了两个章节:基础和进阶,其中基础部分已经在之前学习后,下面一起来看看真正的难字怎么写。
深入生命周期
其实关于生命周期的常用特性,在上一节中,我们已经概括得差不多了,本章主要讲解生命周期的一些高级或者不为人知的特性。对于新手,完全可以跳过本节内容,进行下一章节的学习。
不太聪明的生命周期检查
在 Rust 语言学习中,一个很重要的部分就是阅读一些你可能不经常遇到,但是一旦遇到就难以理解的代码,这些代码往往最令人头疼的就是生命周期,这里我们就来看看一些本以为可以编译,但是却因为生命周期系统不够聪明导致编译失败的代码。
例子 1
#[derive(Debug)] struct Foo; impl Foo { fn mutate_and_share(&mut self) -> &Self { &*self } fn share(&self) {} } fn main() { let mut foo = Foo; let loan = foo.mutate_and_share(); foo.share(); println!("{:?}", loan); }
上面的代码中,foo.mutate_and_share() 虽然借用了 &mut self,但是它最终返回的是一个 &self,然后赋值给 loan,因此理论上来说它最终是进行了不可变借用,同时 foo.share 也进行了不可变借用,那么根据 Rust 的借用规则:多个不可变借用可以同时存在,因此该代码应该编译通过。
事实上,运行代码后,你将看到一个错误:
error[E0502]: cannot borrow `foo` as immutable because it is also borrowed as mutable
--> src/main.rs:12:5
|
11 | let loan = foo.mutate_and_share();
| ---------------------- mutable borrow occurs here
12 | foo.share();
| ^^^^^^^^^^^ immutable borrow occurs here
13 | println!("{:?}", loan);
| ---- mutable borrow later used here
编译器的提示在这里其实有些难以理解,因为可变借用仅在 mutate_and_share 方法内部有效,出了该方法后,就只有返回的不可变借用,因此,按理来说可变借用不应该在 main 的作用范围内存在。
对于这个反直觉的事情,让我们用生命周期来解释下,可能你就很好理解了:
struct Foo; impl Foo { fn mutate_and_share<'a>(&'a mut self) -> &'a Self { &'a *self } fn share<'a>(&'a self) {} } fn main() { 'b: { let mut foo: Foo = Foo; 'c: { let loan: &'c Foo = Foo::mutate_and_share::<'c>(&'c mut foo); 'd: { Foo::share::<'d>(&'d foo); } println!("{:?}", loan); } } }
以上是模拟了编译器的生命周期标注后的代码,可以注意到 &mut foo 和 loan 的生命周期都是 'c。
还记得生命周期消除规则中的第三条吗?因为该规则,导致了 mutate_and_share 方法中,参数 &mut self 和返回值 &self 的生命周期是相同的,因此,若返回值的生命周期在 main 函数有效,那 &mut self 的借用也是在 main 函数有效。
这就解释了可变借用为啥会在 main 函数作用域内有效,最终导致 foo.share() 无法再进行不可变借用。
总结下:&mut self 借用的生命周期和 loan 的生命周期相同,将持续到 println 结束。而在此期间 foo.share() 又进行了一次不可变 &foo 借用,违背了可变借用与不可变借用不能同时存在的规则,最终导致了编译错误。
上述代码实际上完全是正确的,但是因为生命周期系统的“粗糙实现”,导致了编译错误,目前来说,遇到这种生命周期系统不够聪明导致的编译错误,我们也没有太好的办法,只能修改代码去满足它的需求,并期待以后它会更聪明。
例子 2
再来看一个例子:
#![allow(unused)] fn main() { use std::collections::HashMap; use std::hash::Hash; fn get_default<'m, K, V>(map: &'m mut HashMap<K, V>, key: K) -> &'m mut V where K: Clone + Eq + Hash, V: Default, { match map.get_mut(&key) { Some(value) => value, None => { map.insert(key.clone(), V::default()); map.get_mut(&key).unwrap() } } } }
这段代码不能通过编译的原因是编译器未能精确地判断出某个可变借用不再需要,反而谨慎的给该借用安排了一个很大的作用域,结果导致后续的借用失败:
error[E0499]: cannot borrow `*map` as mutable more than once at a time
--> src/main.rs:13:17
|
5 | fn get_default<'m, K, V>(map: &'m mut HashMap<K, V>, key: K) -> &'m mut V
| -- lifetime `'m` defined here
...
10 | match map.get_mut(&key) {
| - ----------------- first mutable borrow occurs here
| _________|
| |
11 | | Some(value) => value,
12 | | None => {
13 | | map.insert(key.clone(), V::default());
| | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ second mutable borrow occurs here
14 | | map.get_mut(&key).unwrap()
15 | | }
16 | | }
| |_________- returning this value requires that `*map` is borrowed for `'m`
分析代码可知在 match map.get_mut(&key) 方法调用完成后,对 map 的可变借用就可以结束了。但从报错看来,编译器不太聪明,它认为该借用会持续到整个 match 语句块的结束(第 16 行处),这便造成了后续借用的失败。
类似的例子还有很多,由于篇幅有限,就不在这里一一列举,如果大家想要阅读更多的类似代码,可以看看<<Rust 代码鉴赏>>一书。
无界生命周期
不安全代码(unsafe)经常会凭空产生引用或生命周期,这些生命周期被称为是 无界(unbound) 的。
无界生命周期往往是在解引用一个裸指针(裸指针 raw pointer)时产生的,换句话说,它是凭空产生的,因为输入参数根本就没有这个生命周期:
#![allow(unused)] fn main() { fn f<'a, T>(x: *const T) -> &'a T { unsafe { &*x } } }
上述代码中,参数 x 是一个裸指针,它并没有任何生命周期,然后通过 unsafe 操作后,它被进行了解引用,变成了一个 Rust 的标准引用类型,该类型必须要有生命周期,也就是 'a。
可以看出 'a 是凭空产生的,因此它是无界生命周期。这种生命周期由于没有受到任何约束,因此它想要多大就多大,这实际上比 'static 要强大。例如 &'static &'a T 是无效类型,但是无界生命周期 &'unbounded &'a T 会被视为 &'a &'a T 从而通过编译检查,因为它可大可小,就像孙猴子的金箍棒一般。
我们在实际应用中,要尽量避免这种无界生命周期。最简单的避免无界生命周期的方式就是在函数声明中运用生命周期消除规则。若一个输出生命周期被消除了,那么必定因为有一个输入生命周期与之对应。
生命周期约束 HRTB
生命周期约束跟特征约束类似,都是通过形如 'a: 'b 的语法,来说明两个生命周期的长短关系。
'a: 'b
假设有两个引用 &'a i32 和 &'b i32,它们的生命周期分别是 'a 和 'b,若 'a >= 'b,则可以定义 'a:'b,表示 'a 至少要活得跟 'b 一样久。
#![allow(unused)] fn main() { struct DoubleRef<'a,'b:'a, T> { r: &'a T, s: &'b T } }
例如上述代码定义一个结构体,它拥有两个引用字段,类型都是泛型 T,每个引用都拥有自己的生命周期,由于我们使用了生命周期约束 'b: 'a,因此 'b 必须活得比 'a 久,也就是结构体中的 s 字段引用的值必须要比 r 字段引用的值活得要久。
T: 'a
表示类型 T 必须比 'a 活得要久:
#![allow(unused)] fn main() { struct Ref<'a, T: 'a> { r: &'a T } }
因为结构体字段 r 引用了 T,因此 r 的生命周期 'a 必须要比 T 的生命周期更短(被引用者的生命周期必须要比引用长)。
在 Rust 1.30 版本之前,该写法是必须的,但是从 1.31 版本开始,编译器可以自动推导 T: 'a 类型的约束,因此我们只需这样写即可:
#![allow(unused)] fn main() { struct Ref<'a, T> { r: &'a T } }
来看一个使用了生命周期约束的综合例子:
#![allow(unused)] fn main() { struct ImportantExcerpt<'a> { part: &'a str, } impl<'a: 'b, 'b> ImportantExcerpt<'a> { fn announce_and_return_part(&'a self, announcement: &'b str) -> &'b str { println!("Attention please: {}", announcement); self.part } } }
上面的例子中必须添加约束 'a: 'b 后,才能成功编译,因为 self.part 的生命周期与 self的生命周期一致,将 &'a 类型的生命周期强行转换为 &'b 类型,会报错,只有在 'a >= 'b 的情况下,'a 才能转换成 'b。
闭包函数的消除规则
先来看一段简单的代码:
#![allow(unused)] fn main() { fn fn_elision(x: &i32) -> &i32 { x } let closure_slision = |x: &i32| -> &i32 { x }; }
乍一看,这段代码比古天乐还平平无奇,能有什么问题呢?来,拄拐走两圈试试:
error: lifetime may not live long enough
--> src/main.rs:39:39
|
39 | let closure = |x: &i32| -> &i32 { x }; // fails
| - - ^ returning this value requires that `'1` must outlive `'2`
| | |
| | let's call the lifetime of this reference `'2`
| let's call the lifetime of this reference `'1`
咦?竟然报错了,明明两个一模一样功能的函数,一个正常编译,一个却报错,错误原因是编译器无法推测返回的引用和传入的引用谁活得更久!
真的是非常奇怪的错误,学过上一节的读者应该都记得这样一条生命周期消除规则:如果函数参数中只有一个引用类型,那该引用的生命周期会被自动分配给所有的返回引用。我们当前的情况完美符合, function 函数的顺利编译通过,就充分说明了问题。
先给出一个结论:这个问题,可能很难被解决,建议大家遇到后,还是老老实实用正常的函数,不要秀闭包了。
对于函数的生命周期而言,它的消除规则之所以能生效是因为它的生命周期完全体现在签名的引用类型上,在函数体中无需任何体现:
#![allow(unused)] fn main() { fn fn_elision(x: &i32) -> &i32 {..} }
因此编译器可以做各种编译优化,也很容易根据参数和返回值进行生命周期的分析,最终得出消除规则。
可是闭包,并没有函数那么简单,它的生命周期分散在参数和闭包函数体中(主要是它没有确切的返回值签名):
#![allow(unused)] fn main() { let closure_slision = |x: &i32| -> &i32 { x }; }
编译器就必须深入到闭包函数体中,去分析和推测生命周期,复杂度因此极具提升:试想一下,编译器该如何从复杂的上下文中分析出参数引用的生命周期和闭包体中生命周期的关系?
由于上述原因(当然,实际情况复杂的多),Rust 语言开发者目前其实是有意针对函数和闭包实现了两种不同的生命周期消除规则。
用
Fn特征解决闭包生命周期之前我们提到了很难解决,但是并没有完全堵死(论文字的艺术- , -) 这不 @Ykong1337 同学就带了一个解决方法,为他点赞!
fn main() { let closure_slision = fun(|x: &i32| -> &i32 { x }); assert_eq!(*closure_slision(&45), 45); // Passed ! } fn fun<T, F: Fn(&T) -> &T>(f: F) -> F { f }
NLL (Non-Lexical Lifetime)
之前我们在引用与借用那一章其实有讲到过这个概念,简单来说就是:引用的生命周期正常来说应该从借用开始一直持续到作用域结束,但是这种规则会让多引用共存的情况变得更复杂:
fn main() { let mut s = String::from("hello"); let r1 = &s; let r2 = &s; println!("{} and {}", r1, r2); // 新编译器中,r1,r2作用域在这里结束 let r3 = &mut s; println!("{}", r3); }
按照上述规则,这段代码将会报错,因为 r1 和 r2 的不可变引用将持续到 main 函数结束,而在此范围内,我们又借用了 r3 的可变引用,这违反了借用的规则:要么多个不可变借用,要么一个可变借用。
好在,该规则从 1.31 版本引入 NLL 后,就变成了:引用的生命周期从借用处开始,一直持续到最后一次使用的地方。
按照最新的规则,我们再来分析一下上面的代码。r1 和 r2 不可变借用在 println! 后就不再使用,因此生命周期也随之结束,那么 r3 的借用就不再违反借用的规则,皆大欢喜。
再来看一段关于 NLL 的代码解释:
#![allow(unused)] fn main() { let mut u = 0i32; let mut v = 1i32; let mut w = 2i32; // lifetime of `a` = α ∪ β ∪ γ let mut a = &mut u; // --+ α. lifetime of `&mut u` --+ lexical "lifetime" of `&mut u`,`&mut u`, `&mut w` and `a` use(a); // | | *a = 3; // <-----------------+ | ... // | a = &mut v; // --+ β. lifetime of `&mut v` | use(a); // | | *a = 4; // <-----------------+ | ... // | a = &mut w; // --+ γ. lifetime of `&mut w` | use(a); // | | *a = 5; // <-----------------+ <--------------------------+ }
这段代码一目了然,a 有三段生命周期:α,β,γ,每一段生命周期都随着当前值的最后一次使用而结束。
在实际项目中,NLL 规则可以大幅减少引用冲突的情况,极大的便利了用户,因此广受欢迎,最终该规则甚至演化成一个独立的项目,未来可能会进一步简化我们的使用,Polonius:
Reborrow 再借用
学完 NLL 后,我们就有了一定的基础,可以继续学习关于借用和生命周期的一个高级内容:再借用。
先来看一段代码:
#[derive(Debug)] struct Point { x: i32, y: i32, } impl Point { fn move_to(&mut self, x: i32, y: i32) { self.x = x; self.y = y; } } fn main() { let mut p = Point { x: 0, y: 0 }; let r = &mut p; let rr: &Point = &*r; println!("{:?}", rr); r.move_to(10, 10); println!("{:?}", r); }
以上代码,大家可能会觉得可变引用 r 和不可变引用 rr 同时存在会报错吧?但是事实上并不会,原因在于 rr 是对 r 的再借用。
对于再借用而言,rr 再借用时不会破坏借用规则,但是你不能在它的生命周期内再使用原来的借用 r,来看看对上段代码的分析:
fn main() { let mut p = Point { x: 0, y: 0 }; let r = &mut p; // reborrow! 此时对`r`的再借用不会导致跟上面的借用冲突 let rr: &Point = &*r; // 再借用`rr`最后一次使用发生在这里,在它的生命周期中,我们并没有使用原来的借用`r`,因此不会报错 println!("{:?}", rr); // 再借用结束后,才去使用原来的借用`r` r.move_to(10, 10); println!("{:?}", r); }
再来看一个例子:
#![allow(unused)] fn main() { use std::vec::Vec; fn read_length(strings: &mut Vec<String>) -> usize { strings.len() } }
如上所示,函数体内对参数的二次借用也是典型的 Reborrow 场景。
那么下面让我们来做件坏事,破坏这条规则,使其报错:
fn main() { let mut p = Point { x: 0, y: 0 }; let r = &mut p; let rr: &Point = &*r; r.move_to(10, 10); println!("{:?}", rr); println!("{:?}", r); }
果然,破坏永远比重建简单 :) 只需要在 rr 再借用的生命周期内使用一次原来的借用 r 即可!
生命周期消除规则补充
在上一节中,我们介绍了三大基础生命周期消除规则,实际上,随着 Rust 的版本进化,该规则也在不断演进,这里再介绍几个常见的消除规则:
impl 块消除
#![allow(unused)] fn main() { impl<'a> Reader for BufReader<'a> { // methods go here // impl内部实际上没有用到'a } }
如果你以前写的impl块长上面这样,同时在 impl 内部的方法中,根本就没有用到 'a,那就可以写成下面的代码形式。
#![allow(unused)] fn main() { impl Reader for BufReader<'_> { // methods go here } }
'_ 生命周期表示 BufReader 有一个不使用的生命周期,我们可以忽略它,无需为它创建一个名称。
歪个楼,有读者估计会发问:既然用不到 'a,为何还要写出来?如果你仔细回忆下上一节的内容,里面有一句专门用粗体标注的文字:生命周期参数也是类型的一部分,因此 BufReader<'a> 是一个完整的类型,在实现它的时候,你不能把 'a 给丢了!
生命周期约束消除
#![allow(unused)] fn main() { // Rust 2015 struct Ref<'a, T: 'a> { field: &'a T } // Rust 2018 struct Ref<'a, T> { field: &'a T } }
在本节的生命周期约束中,也提到过,新版本 Rust 中,上面情况中的 T: 'a 可以被消除掉,当然,你也可以显式的声明,但是会影响代码可读性。关于类似的场景,Rust 团队计划在未来提供更多的消除规则,但是,你懂的,计划未来就等于未知。
一个复杂的例子
下面是一个关于生命周期声明过大的例子,会较为复杂,希望大家能细细阅读,它能帮你对生命周期的理解更加深入。
struct Interface<'a> { manager: &'a mut Manager<'a> } impl<'a> Interface<'a> { pub fn noop(self) { println!("interface consumed"); } } struct Manager<'a> { text: &'a str } struct List<'a> { manager: Manager<'a>, } impl<'a> List<'a> { pub fn get_interface(&'a mut self) -> Interface { Interface { manager: &mut self.manager } } } fn main() { let mut list = List { manager: Manager { text: "hello" } }; list.get_interface().noop(); println!("Interface should be dropped here and the borrow released"); // 下面的调用会失败,因为同时有不可变/可变借用 // 但是Interface在之前调用完成后就应该被释放了 use_list(&list); } fn use_list(list: &List) { println!("{}", list.manager.text); }
运行后报错:
error[E0502]: cannot borrow `list` as immutable because it is also borrowed as mutable // `list`无法被借用,因为已经被可变借用
--> src/main.rs:40:14
|
34 | list.get_interface().noop();
| ---- mutable borrow occurs here // 可变借用发生在这里
...
40 | use_list(&list);
| ^^^^^
| |
| immutable borrow occurs here // 新的不可变借用发生在这
| mutable borrow later used here // 可变借用在这里结束
这段代码看上去并不复杂,实际上难度挺高的,首先在直觉上,list.get_interface() 借用的可变引用,按理来说应该在这行代码结束后,就归还了,但是为什么还能持续到 use_list(&list) 后面呢?
这是因为我们在 get_interface 方法中声明的 lifetime 有问题,该方法的参数的生命周期是 'a,而 List 的生命周期也是 'a,说明该方法至少活得跟 List 一样久,再回到 main 函数中,list 可以活到 main 函数的结束,因此 list.get_interface() 借用的可变引用也会活到 main 函数的结束,在此期间,自然无法再进行借用了。
要解决这个问题,我们需要为 get_interface 方法的参数给予一个不同于 List<'a> 的生命周期 'b,最终代码如下:
struct Interface<'b, 'a: 'b> { manager: &'b mut Manager<'a> } impl<'b, 'a: 'b> Interface<'b, 'a> { pub fn noop(self) { println!("interface consumed"); } } struct Manager<'a> { text: &'a str } struct List<'a> { manager: Manager<'a>, } impl<'a> List<'a> { pub fn get_interface<'b>(&'b mut self) -> Interface<'b, 'a> where 'a: 'b { Interface { manager: &mut self.manager } } } fn main() { let mut list = List { manager: Manager { text: "hello" } }; list.get_interface().noop(); println!("Interface should be dropped here and the borrow released"); // 下面的调用可以通过,因为Interface的生命周期不需要跟list一样长 use_list(&list); } fn use_list(list: &List) { println!("{}", list.manager.text); }
至此,生命周期终于完结,两章超级长的内容,可以满足几乎所有对生命周期的学习目标。学完生命周期,意味着你正式入门了 Rust,只要再掌握几个常用概念,就可以上手写项目了。
&'static 和 T: 'static
Rust 的难点之一就在于它有不少容易混淆的概念,例如 &str 、str 与 String, 再比如本文标题那两位。不过与字符串也有不同,这两位对于普通用户来说往往是无需进行区分的,但是当大家想要深入学习或使用 Rust 时,它们就会成为成功路上的拦路虎了。
与生命周期的其它章节不同,本文短小精悍,阅读过程可谓相当轻松愉快,话不多说,let's go。
'static 在 Rust 中是相当常见的,例如字符串字面值就具有 'static 生命周期:
fn main() { let mark_twain: &str = "Samuel Clemens"; print_author(mark_twain); } fn print_author(author: &'static str) { println!("{}", author); }
除此之外,特征对象的生命周期也是 'static,例如这里所提到的。
除了 &'static 的用法外,我们在另外一种场景中也可以见到 'static 的使用:
use std::fmt::Display; fn main() { let mark_twain = "Samuel Clemens"; print(&mark_twain); } fn print<T: Display + 'static>(message: &T) { println!("{}", message); }
在这里,很明显 'static 是作为生命周期约束来使用了。 那么问题来了, &'static 和 T: 'static 的用法到底有何区别?
&'static
&'static 对于生命周期有着非常强的要求:一个引用必须要活得跟剩下的程序一样久,才能被标注为 &'static。
对于字符串字面量来说,它直接被打包到二进制文件中,永远不会被 drop,因此它能跟程序活得一样久,自然它的生命周期是 'static。
但是,&'static 生命周期针对的仅仅是引用,而不是持有该引用的变量,对于变量来说,还是要遵循相应的作用域规则 :
use std::{slice::from_raw_parts, str::from_utf8_unchecked}; fn get_memory_location() -> (usize, usize) { // “Hello World” 是字符串字面量,因此它的生命周期是 `'static`. // 但持有它的变量 `string` 的生命周期就不一样了,它完全取决于变量作用域,对于该例子来说,也就是当前的函数范围 let string = "Hello World!"; let pointer = string.as_ptr() as usize; let length = string.len(); (pointer, length) // `string` 在这里被 drop 释放 // 虽然变量被释放,无法再被访问,但是数据依然还会继续存活 } fn get_str_at_location(pointer: usize, length: usize) -> &'static str { // 使用裸指针需要 `unsafe{}` 语句块 unsafe { from_utf8_unchecked(from_raw_parts(pointer as *const u8, length)) } } fn main() { let (pointer, length) = get_memory_location(); let message = get_str_at_location(pointer, length); println!( "The {} bytes at 0x{:X} stored: {}", length, pointer, message ); // 如果大家想知道为何处理裸指针需要 `unsafe`,可以试着反注释以下代码 // let message = get_str_at_location(1000, 10); }
上面代码有两点值得注意:
&'static的引用确实可以和程序活得一样久,因为我们通过get_str_at_location函数直接取到了对应的字符串- 持有
&'static引用的变量,它的生命周期受到作用域的限制,大家务必不要搞混了
T: 'static
相比起来,这种形式的约束就有些复杂了。
首先,在以下两种情况下,T: 'static 与 &'static 有相同的约束:T 必须活得和程序一样久。
use std::fmt::Debug; fn print_it<T: Debug + 'static>( input: T) { println!( "'static value passed in is: {:?}", input ); } fn print_it1( input: impl Debug + 'static ) { println!( "'static value passed in is: {:?}", input ); } fn main() { let i = 5; print_it(&i); print_it1(&i); }
以上代码会报错,原因很简单: &i 的生命周期无法满足 'static 的约束,如果大家将 i 修改为常量,那自然一切 OK。
见证奇迹的时候,请不要眨眼,现在我们来稍微修改下 print_it 函数:
use std::fmt::Debug; fn print_it<T: Debug + 'static>( input: &T) { println!( "'static value passed in is: {:?}", input ); } fn main() { let i = 5; print_it(&i); }
这段代码竟然不报错了!原因在于我们约束的是 T,但是使用的却是它的引用 &T,换而言之,我们根本没有直接使用 T,因此编译器就没有去检查 T 的生命周期约束!它只要确保 &T 的生命周期符合规则即可,在上面代码中,它自然是符合的。
再来看一个例子:
use std::fmt::Display; fn main() { let r1; let r2; { static STATIC_EXAMPLE: i32 = 42; r1 = &STATIC_EXAMPLE; let x = "&'static str"; r2 = x; // r1 和 r2 持有的数据都是 'static 的,因此在花括号结束后,并不会被释放 } println!("&'static i32: {}", r1); // -> 42 println!("&'static str: {}", r2); // -> &'static str let r3: &str; { let s1 = "String".to_string(); // s1 虽然没有 'static 生命周期,但是它依然可以满足 T: 'static 的约束 // 充分说明这个约束是多么的弱。。 static_bound(&s1); // s1 是 String 类型,没有 'static 的生命周期,因此下面代码会报错 r3 = &s1; // s1 在这里被 drop } println!("{}", r3); } fn static_bound<T: Display + 'static>(t: &T) { println!("{}", t); }
static 到底针对谁?
大家有没有想过,到底是 &'static 这个引用还是该引用指向的数据活得跟程序一样久呢?
答案是引用指向的数据,而引用本身是要遵循其作用域范围的,我们来简单验证下:
fn main() { { let static_string = "I'm in read-only memory"; println!("static_string: {}", static_string); // 当 `static_string` 超出作用域时,该引用不能再被使用,但是数据依然会存在于 binary 所占用的内存中 } println!("static_string reference remains alive: {}", static_string); }
以上代码不出所料会报错,原因在于虽然字符串字面量 "I'm in read-only memory" 的生命周期是 'static,但是持有它的引用并不是,它的作用域在内部花括号 } 处就结束了。
课后练习
Rust By Practice,支持代码在线编辑和运行,并提供详细的习题解答。(本节暂无习题解答)
总结
总之, &'static 和 T: 'static 大体上相似,相比起来,后者的使用形式会更加复杂一些。
至此,相信大家对于 'static 和 T: 'static 也有了清晰的理解,那么我们应该如何使用它们呢?
作为经验之谈,可以这么来:
- 如果你需要添加
&'static来让代码工作,那很可能是设计上出问题了 - 如果你希望满足和取悦编译器,那就使用
T: 'static,很多时候它都能解决问题
一个小知识,在 Rust 标准库中,有 48 处用到了 &'static ,112 处用到了
T: 'static,看来取悦编译器不仅仅是菜鸟需要的,高手也经常用到 :)
函数式编程
罗马不是一天建成的,编程语言亦是如此,每一门编程语言在借鉴前辈的同时,也会提出自己独有的特性,Rust 即是如此。当站在巨人肩膀上时,一个人所能看到的就更高更远,恰好,我们看到了函数式语言的优秀特性,例如:
- 使用函数作为参数进行传递
- 使用函数作为函数返回值
- 将函数赋值给变量
见猎心喜,我们忍不住就借鉴了过来,于是你能看到本章的内容,天下语言一大。。。跑题了。
关于函数式编程到底是什么的争论由来已久,本章节并不会踏足这个泥潭,因此我们在这里主要关注的是函数式特性:
- 闭包 Closure
- 迭代器 Iterator
- 模式匹配
- 枚举
其中后两个在前面章节我们已经深入学习过,因此本章的重点就是闭包和迭代器,这些函数式特性可以让代码的可读性和易写性大幅提升。对于 Rust 语言来说,掌握这两者就相当于你同时拥有了倚天剑屠龙刀,威力无穷。
闭包 Closure
闭包这个词语由来已久,自上世纪 60 年代就由 Scheme 语言引进之后,被广泛用于函数式编程语言中,进入 21 世纪后,各种现代化的编程语言也都不约而同地把闭包作为核心特性纳入到语言设计中来。那么到底何为闭包?
闭包是一种匿名函数,它可以赋值给变量也可以作为参数传递给其它函数,不同于函数的是,它允许捕获调用者作用域中的值,例如:
fn main() { let x = 1; let sum = |y| x + y; assert_eq!(3, sum(2)); }
上面的代码展示了非常简单的闭包 sum,它拥有一个入参 y,同时捕获了作用域中的 x 的值,因此调用 sum(2) 意味着将 2(参数 y)跟 1(x)进行相加,最终返回它们的和:3。
可以看到 sum 非常符合闭包的定义:可以赋值给变量,允许捕获调用者作用域中的值。
使用闭包来简化代码
传统函数实现
想象一下,我们要进行健身,用代码怎么实现(写代码什么鬼,健身难道不应该去健身房嘛?答曰:健身太累了,还是虚拟健身好,点到为止)?这里是我的想法:
use std::thread; use std::time::Duration; // 开始健身,好累,我得发出声音:muuuu... fn muuuuu(intensity: u32) -> u32 { println!("muuuu....."); thread::sleep(Duration::from_secs(2)); intensity } fn workout(intensity: u32, random_number: u32) { if intensity < 25 { println!( "今天活力满满,先做 {} 个俯卧撑!", muuuuu(intensity) ); println!( "旁边有妹子在看,俯卧撑太low,再来 {} 组卧推!", muuuuu(intensity) ); } else if random_number == 3 { println!("昨天练过度了,今天还是休息下吧!"); } else { println!( "昨天练过度了,今天干干有氧,跑步 {} 分钟!", muuuuu(intensity) ); } } fn main() { // 强度 let intensity = 10; // 随机值用来决定某个选择 let random_number = 7; // 开始健身 workout(intensity, random_number); }
可以看到,在健身时我们根据想要的强度来调整具体的动作,然后调用 muuuuu 函数来开始健身。这个程序本身很简单,没啥好说的,但是假如未来不用 muuuuu 函数了,是不是得把所有 muuuuu 都替换成,比如说 woooo ?如果 muuuuu 出现了几十次,那意味着我们要修改几十处地方。
函数变量实现
一个可行的办法是,把函数赋值给一个变量,然后通过变量调用:
use std::thread; use std::time::Duration; // 开始健身,好累,我得发出声音:muuuu... fn muuuuu(intensity: u32) -> u32 { println!("muuuu....."); thread::sleep(Duration::from_secs(2)); intensity } fn workout(intensity: u32, random_number: u32) { let action = muuuuu; if intensity < 25 { println!( "今天活力满满, 先做 {} 个俯卧撑!", action(intensity) ); println!( "旁边有妹子在看,俯卧撑太low, 再来 {} 组卧推!", action(intensity) ); } else if random_number == 3 { println!("昨天练过度了,今天还是休息下吧!"); } else { println!( "昨天练过度了,今天干干有氧, 跑步 {} 分钟!", action(intensity) ); } } fn main() { // 强度 let intensity = 10; // 随机值用来决定某个选择 let random_number = 7; // 开始健身 workout(intensity, random_number); }
经过上面修改后,所有的调用都通过 action 来完成,若未来声(动)音(作)变了,只要修改为 let action = woooo 即可。
但是问题又来了,若 intensity 也变了怎么办?例如变成 action(intensity + 1),那你又得哐哐哐修改几十处调用。
该怎么办?没太好的办法了,只能祭出大杀器:闭包。
闭包实现
上面提到 intensity 要是变化怎么办,简单,使用闭包来捕获它,这是我们的拿手好戏:
use std::thread; use std::time::Duration; fn workout(intensity: u32, random_number: u32) { let action = || { println!("muuuu....."); thread::sleep(Duration::from_secs(2)); intensity }; if intensity < 25 { println!( "今天活力满满,先做 {} 个俯卧撑!", action() ); println!( "旁边有妹子在看,俯卧撑太low,再来 {} 组卧推!", action() ); } else if random_number == 3 { println!("昨天练过度了,今天还是休息下吧!"); } else { println!( "昨天练过度了,今天干干有氧,跑步 {} 分钟!", action() ); } } fn main() { // 动作次数 let intensity = 10; // 随机值用来决定某个选择 let random_number = 7; // 开始健身 workout(intensity, random_number); }
在上面代码中,无论你要修改什么,只要修改闭包 action 的实现即可,其它地方只负责调用,完美解决了我们的问题!
Rust 闭包在形式上借鉴了 Smalltalk 和 Ruby 语言,与函数最大的不同就是它的参数是通过 |parm1| 的形式进行声明,如果是多个参数就 |param1, param2,...|, 下面给出闭包的形式定义:
#![allow(unused)] fn main() { |param1, param2,...| { 语句1; 语句2; 返回表达式 } }
如果只有一个返回表达式的话,定义可以简化为:
#![allow(unused)] fn main() { |param1| 返回表达式 }
上例中还有两点值得注意:
- 闭包中最后一行表达式返回的值,就是闭包执行后的返回值,因此
action()调用返回了intensity的值10 let action = ||...只是把闭包赋值给变量action,并不是把闭包执行后的结果赋值给action,因此这里action就相当于闭包函数,可以跟函数一样进行调用:action()
闭包的类型推导
Rust 是静态语言,因此所有的变量都具有类型,但是得益于编译器的强大类型推导能力,在很多时候我们并不需要显式地去声明类型,但是显然函数并不在此列,必须手动为函数的所有参数和返回值指定类型,原因在于函数往往会作为 API 提供给你的用户,因此你的用户必须在使用时知道传入参数的类型和返回值类型。
与函数相反,闭包并不会作为 API 对外提供,因此它可以享受编译器的类型推导能力,无需标注参数和返回值的类型。
为了增加代码可读性,有时候我们会显式地给类型进行标注,出于同样的目的,也可以给闭包标注类型:
#![allow(unused)] fn main() { let sum = |x: i32, y: i32| -> i32 { x + y } }
与之相比,不标注类型的闭包声明会更简洁些:let sum = |x, y| x + y,需要注意的是,针对 sum 闭包,如果你只进行了声明,但是没有使用,编译器会提示你为 x, y 添加类型标注,因为它缺乏必要的上下文:
#![allow(unused)] fn main() { let sum = |x, y| x + y; let v = sum(1, 2); }
这里我们使用了 sum,同时把 1 传给了 x,2 传给了 y,因此编译器才可以推导出 x,y 的类型为 i32。
下面展示了同一个功能的函数和闭包实现形式:
#![allow(unused)] fn main() { fn add_one_v1 (x: u32) -> u32 { x + 1 } let add_one_v2 = |x: u32| -> u32 { x + 1 }; let add_one_v3 = |x| { x + 1 }; let add_one_v4 = |x| x + 1 ; }
可以看出第一行的函数和后面的闭包其实在形式上是非常接近的,同时三种不同的闭包也展示了三种不同的使用方式:省略参数、返回值类型和花括号对。
虽然类型推导很好用,但是它不是泛型,当编译器推导出一种类型后,它就会一直使用该类型:
#![allow(unused)] fn main() { let example_closure = |x| x; let s = example_closure(String::from("hello")); let n = example_closure(5); }
首先,在 s 中,编译器为 x 推导出类型 String,但是紧接着 n 试图用 5 这个整型去调用闭包,跟编译器之前推导的 String 类型不符,因此报错:
error[E0308]: mismatched types
--> src/main.rs:5:29
|
5 | let n = example_closure(5);
| ^
| |
| expected struct `String`, found integer // 期待String类型,却发现一个整数
| help: try using a conversion method: `5.to_string()`
结构体中的闭包
假设我们要实现一个简易缓存,功能是获取一个值,然后将其缓存起来,那么可以这样设计:
- 一个闭包用于获取值
- 一个变量,用于存储该值
可以使用结构体来代表缓存对象,最终设计如下:
#![allow(unused)] fn main() { struct Cacher<T> where T: Fn(u32) -> u32, { query: T, value: Option<u32>, } }
等等,我都跟着这本教程学完 Rust 基础了,为何还有我不认识的东东?Fn(u32) -> u32 是什么鬼?别急,先回答你第一个问题:骚年,too young too naive,你以为 Rust 的语法特性就基础入门那一些吗?太年轻了!如果是长征,你才刚到赤水河。
其实,可以看得出这一长串是 T 的特征约束,再结合之前的已知信息:query 是一个闭包,大概可以推测出,Fn(u32) -> u32 是一个特征,用来表示 T 是一个闭包类型?Bingo,恭喜你,答对了!
那为什么不用具体的类型来标注 query 呢?原因很简单,每一个闭包实例都有独属于自己的类型,即使于两个签名一模一样的闭包,它们的类型也是不同的,因此你无法用一个统一的类型来标注 query 闭包。
而标准库提供的 Fn 系列特征,再结合特征约束,就能很好的解决了这个问题. T: Fn(u32) -> u32 意味着 query 的类型是 T,该类型必须实现了相应的闭包特征 Fn(u32) -> u32。从特征的角度来看它长得非常反直觉,但是如果从闭包的角度来看又极其符合直觉,不得不佩服 Rust 团队的鬼才设计。。。
特征 Fn(u32) -> u32 从表面来看,就对闭包形式进行了显而易见的限制:该闭包拥有一个u32类型的参数,同时返回一个u32类型的值。
需要注意的是,其实 Fn 特征不仅仅适用于闭包,还适用于函数,因此上面的
query字段除了使用闭包作为值外,还能使用一个具名的函数来作为它的值
接着,为缓存实现方法:
#![allow(unused)] fn main() { impl<T> Cacher<T> where T: Fn(u32) -> u32, { fn new(query: T) -> Cacher<T> { Cacher { query, value: None, } } // 先查询缓存值 `self.value`,若不存在,则调用 `query` 加载 fn value(&mut self, arg: u32) -> u32 { match self.value { Some(v) => v, None => { let v = (self.query)(arg); self.value = Some(v); v } } } } }
上面的缓存有一个很大的问题:只支持 u32 类型的值,若我们想要缓存 &str 类型,显然就行不通了,因此需要将 u32 替换成泛型 E,该练习就留给读者自己完成,具体代码可以参考这里
捕获作用域中的值
在之前代码中,我们一直在用闭包的匿名函数特性(赋值给变量),然而闭包还拥有一项函数所不具备的特性:捕获作用域中的值。
fn main() { let x = 4; let equal_to_x = |z| z == x; let y = 4; assert!(equal_to_x(y)); }
上面代码中,x 并不是闭包 equal_to_x 的参数,但是它依然可以去使用 x,因为 equal_to_x 在 x 的作用域范围内。
对于函数来说,就算你把函数定义在 main 函数体中,它也不能访问 x:
fn main() { let x = 4; fn equal_to_x(z: i32) -> bool { z == x } let y = 4; assert!(equal_to_x(y)); }
报错如下:
error[E0434]: can't capture dynamic environment in a fn item // 在函数中无法捕获动态的环境
--> src/main.rs:5:14
|
5 | z == x
| ^
|
= help: use the `|| { ... }` closure form instead // 使用闭包替代
如上所示,编译器准确地告诉了我们错误,甚至同时给出了提示:使用闭包来替代函数,这种聪明令我有些无所适从,总感觉会显得我很笨。
闭包对内存的影响
当闭包从环境中捕获一个值时,会分配内存去存储这些值。对于有些场景来说,这种额外的内存分配会成为一种负担。与之相比,函数就不会去捕获这些环境值,因此定义和使用函数不会拥有这种内存负担。
三种 Fn 特征
闭包捕获变量有三种途径,恰好对应函数参数的三种传入方式:转移所有权、可变借用、不可变借用,因此相应的 Fn 特征也有三种:
FnOnce,该类型的闭包会拿走被捕获变量的所有权。Once顾名思义,说明该闭包只能运行一次:
fn fn_once<F>(func: F) where F: FnOnce(usize) -> bool, { println!("{}", func(3)); println!("{}", func(4)); } fn main() { let x = vec![1, 2, 3]; fn_once(|z|{z == x.len()}) }
仅实现 FnOnce 特征的闭包在调用时会转移所有权,所以显然不能对已失去所有权的闭包变量进行二次调用:
error[E0382]: use of moved value: `func`
--> src\main.rs:6:20
|
1 | fn fn_once<F>(func: F)
| ---- move occurs because `func` has type `F`, which does not implement the `Copy` trait
// 因为`func`的类型是没有实现`Copy`特性的 `F`,所以发生了所有权的转移
...
5 | println!("{}", func(3));
| ------- `func` moved due to this call // 转移在这
6 | println!("{}", func(4));
| ^^^^ value used here after move // 转移后再次用
|
这里面有一个很重要的提示,因为 F 没有实现 Copy 特征,所以会报错,那么我们添加一个约束,试试实现了 Copy 的闭包:
fn fn_once<F>(func: F) where F: FnOnce(usize) -> bool + Copy,// 改动在这里 { println!("{}", func(3)); println!("{}", func(4)); } fn main() { let x = vec![1, 2, 3]; fn_once(|z|{z == x.len()}) }
上面代码中,func 的类型 F 实现了 Copy 特征,调用时使用的将是它的拷贝,所以并没有发生所有权的转移。
true
false
如果你想强制闭包取得捕获变量的所有权,可以在参数列表前添加 move 关键字,这种用法通常用于闭包的生命周期大于捕获变量的生命周期时,例如将闭包返回或移入其他线程。
#![allow(unused)] fn main() { use std::thread; let v = vec![1, 2, 3]; let handle = thread::spawn(move || { println!("Here's a vector: {:?}", v); }); handle.join().unwrap(); }
FnMut,它以可变借用的方式捕获了环境中的值,因此可以修改该值:
fn main() { let mut s = String::new(); let update_string = |str| s.push_str(str); update_string("hello"); println!("{:?}",s); }
在闭包中,我们调用 s.push_str 去改变外部 s 的字符串值,因此这里捕获了它的可变借用,运行下试试:
error[E0596]: cannot borrow `update_string` as mutable, as it is not declared as mutable
--> src/main.rs:5:5
|
4 | let update_string = |str| s.push_str(str);
| ------------- - calling `update_string` requires mutable binding due to mutable borrow of `s`
| |
| help: consider changing this to be mutable: `mut update_string`
5 | update_string("hello");
| ^^^^^^^^^^^^^ cannot borrow as mutable
虽然报错了,但是编译器给出了非常清晰的提示,想要在闭包内部捕获可变借用,需要把该闭包声明为可变类型,也就是 update_string 要修改为 mut update_string:
fn main() { let mut s = String::new(); let mut update_string = |str| s.push_str(str); update_string("hello"); println!("{:?}",s); }
这种写法有点反直觉,相比起来前面的 move 更符合使用和阅读习惯。但是如果你忽略 update_string 的类型,仅仅把它当成一个普通变量,那么这种声明就比较合理了。
再来看一个复杂点的:
fn main() { let mut s = String::new(); let update_string = |str| s.push_str(str); exec(update_string); println!("{:?}",s); } fn exec<'a, F: FnMut(&'a str)>(mut f: F) { f("hello") }
这段代码非常清晰的说明了 update_string 实现了 FnMut 特征
Fn特征,它以不可变借用的方式捕获环境中的值 让我们把上面的代码中exec的F泛型参数类型修改为Fn(&'a str):
fn main() { let mut s = String::new(); let update_string = |str| s.push_str(str); exec(update_string); println!("{:?}",s); } fn exec<'a, F: Fn(&'a str)>(mut f: F) { f("hello") }
然后运行看看结果:
error[E0525]: expected a closure that implements the `Fn` trait, but this closure only implements `FnMut`
--> src/main.rs:4:26 // 期望闭包实现的是`Fn`特征,但是它只实现了`FnMut`特征
|
4 | let update_string = |str| s.push_str(str);
| ^^^^^^-^^^^^^^^^^^^^^
| | |
| | closure is `FnMut` because it mutates the variable `s` here
| this closure implements `FnMut`, not `Fn` //闭包实现的是FnMut,而不是Fn
5 |
6 | exec(update_string);
| ---- the requirement to implement `Fn` derives from here
从报错中很清晰的看出,我们的闭包实现的是 FnMut 特征,需要的是可变借用,但是在 exec 中却给它标注了 Fn 特征,因此产生了不匹配,再来看看正确的不可变借用方式:
fn main() { let s = "hello, ".to_string(); let update_string = |str| println!("{},{}",s,str); exec(update_string); println!("{:?}",s); } fn exec<'a, F: Fn(String) -> ()>(f: F) { f("world".to_string()) }
在这里,因为无需改变 s,因此闭包中只对 s 进行了不可变借用,那么在 exec 中,将其标记为 Fn 特征就完全正确。
move 和 Fn
在上面,我们讲到了 move 关键字对于 FnOnce 特征的重要性,但是实际上使用了 move 的闭包依然可能实现了 Fn 或 FnMut 特征。
因为,一个闭包实现了哪种 Fn 特征取决于该闭包如何使用被捕获的变量,而不是取决于闭包如何捕获它们。move 本身强调的就是后者,闭包如何捕获变量:
fn main() { let s = String::new(); let update_string = move || println!("{}",s); exec(update_string); } fn exec<F: FnOnce()>(f: F) { f() }
我们在上面的闭包中使用了 move 关键字,所以我们的闭包捕获了它,但是由于闭包对 s 的使用仅仅是不可变借用,因此该闭包实际上还实现了 Fn 特征。
细心的读者肯定发现我在上段中使用了一个 还 字,这是什么意思呢?因为该闭包不仅仅实现了 FnOnce 特征,还实现了 Fn 特征,将代码修改成下面这样,依然可以编译:
fn main() { let s = String::new(); let update_string = move || println!("{}",s); exec(update_string); } fn exec<F: Fn()>(f: F) { f() }
三种 Fn 的关系
实际上,一个闭包并不仅仅实现某一种 Fn 特征,规则如下:
- 所有的闭包都自动实现了
FnOnce特征,因此任何一个闭包都至少可以被调用一次 - 没有移出所捕获变量的所有权的闭包自动实现了
FnMut特征 - 不需要对捕获变量进行改变的闭包自动实现了
Fn特征
用一段代码来简单诠释上述规则:
fn main() { let s = String::new(); let update_string = || println!("{}",s); exec(update_string); exec1(update_string); exec2(update_string); } fn exec<F: FnOnce()>(f: F) { f() } fn exec1<F: FnMut()>(mut f: F) { f() } fn exec2<F: Fn()>(f: F) { f() }
虽然,闭包只是对 s 进行了不可变借用,实际上,它可以适用于任何一种 Fn 特征:三个 exec 函数说明了一切。强烈建议读者亲自动手试试各种情况下使用的 Fn 特征,更有助于加深这方面的理解。
关于第二条规则,有如下示例:
fn main() { let mut s = String::new(); let update_string = |str| -> String {s.push_str(str); s }; exec(update_string); } fn exec<'a, F: FnMut(&'a str) -> String>(mut f: F) { f("hello"); }
5 | let update_string = |str| -> String {s.push_str(str); s };
| ^^^^^^^^^^^^^^^ - closure is `FnOnce` because it moves the variable `s` out of its environment
| // 闭包实现了`FnOnce`,因为它从捕获环境中移出了变量`s`
| |
| this closure implements `FnOnce`, not `FnMut`
此例中,闭包从捕获环境中移出了变量 s 的所有权,因此这个闭包仅自动实现了 FnOnce,未实现 FnMut 和 Fn。再次印证之前讲的一个闭包实现了哪种 Fn 特征取决于该闭包如何使用被捕获的变量,而不是取决于闭包如何捕获它们,跟是否使用 move 没有必然联系。
如果还是有疑惑?没关系,我们来看看这三个特征的简化版源码:
#![allow(unused)] fn main() { pub trait Fn<Args> : FnMut<Args> { extern "rust-call" fn call(&self, args: Args) -> Self::Output; } pub trait FnMut<Args> : FnOnce<Args> { extern "rust-call" fn call_mut(&mut self, args: Args) -> Self::Output; } pub trait FnOnce<Args> { type Output; extern "rust-call" fn call_once(self, args: Args) -> Self::Output; } }
看到没?从特征约束能看出来 Fn 的前提是实现 FnMut,FnMut 的前提是实现 FnOnce,因此要实现 Fn 就要同时实现 FnMut 和 FnOnce,这段源码从侧面印证了之前规则的正确性。
从源码中还能看出一点:Fn 获取 &self,FnMut 获取 &mut self,而 FnOnce 获取 self。
在实际项目中,建议先使用 Fn 特征,然后编译器会告诉你正误以及该如何选择。
闭包作为函数返回值
看到这里,相信大家对于如何使用闭包作为函数参数,已经很熟悉了,但是如果要使用闭包作为函数返回值,该如何做?
先来看一段代码:
#![allow(unused)] fn main() { fn factory() -> Fn(i32) -> i32 { let num = 5; |x| x + num } let f = factory(); let answer = f(1); assert_eq!(6, answer); }
上面这段代码看起来还是蛮正常的,用 Fn(i32) -> i32 特征来代表 |x| x + num,非常合理嘛,肯定可以编译通过, 可惜理想总是难以照进现实,编译器给我们报了一大堆错误,先挑几个重点来看看:
fn factory<T>() -> Fn(i32) -> i32 {
| ^^^^^^^^^^^^^^ doesn't have a size known at compile-time // 该类型在编译器没有固定的大小
Rust 要求函数的参数和返回类型,必须有固定的内存大小,例如 i32 就是 4 个字节,引用类型是 8 个字节,总之,绝大部分类型都有固定的大小,但是不包括特征,因为特征类似接口,对于编译器来说,无法知道它后面藏的真实类型是什么,因为也无法得知具体的大小。
同样,我们也无法知道闭包的具体类型,该怎么办呢?再看看报错提示:
help: use `impl Fn(i32) -> i32` as the return type, as all return paths are of type `[closure@src/main.rs:11:5: 11:21]`, which implements `Fn(i32) -> i32`
|
8 | fn factory<T>() -> impl Fn(i32) -> i32 {
嗯,编译器提示我们加一个 impl 关键字,哦,这样一说,读者可能就想起来了,impl Trait 可以用来返回一个实现了指定特征的类型,那么这里 impl Fn(i32) -> i32 的返回值形式,说明我们要返回一个闭包类型,它实现了 Fn(i32) -> i32 特征。
完美解决,但是,在特征那一章,我们提到过,impl Trait 的返回方式有一个非常大的局限,就是你只能返回同样的类型,例如:
#![allow(unused)] fn main() { fn factory(x:i32) -> impl Fn(i32) -> i32 { let num = 5; if x > 1{ move |x| x + num } else { move |x| x - num } } }
运行后,编译器报错:
error[E0308]: `if` and `else` have incompatible types
--> src/main.rs:15:9
|
12 | / if x > 1{
13 | | move |x| x + num
| | ---------------- expected because of this
14 | | } else {
15 | | move |x| x - num
| | ^^^^^^^^^^^^^^^^ expected closure, found a different closure
16 | | }
| |_____- `if` and `else` have incompatible types
|
嗯,提示很清晰:if 和 else 分支中返回了不同的闭包类型,这就很奇怪了,明明这两个闭包长的一样的,好在细心的读者应该回想起来,本章节前面咱们有提到:就算签名一样的闭包,类型也是不同的,因此在这种情况下,就无法再使用 impl Trait 的方式去返回闭包。
怎么办?再看看编译器提示,里面有这样一行小字:
= help: consider boxing your closure and/or using it as a trait object
哦,相信你已经恍然大悟,可以用特征对象!只需要用 Box 的方式即可实现:
#![allow(unused)] fn main() { fn factory(x:i32) -> Box<dyn Fn(i32) -> i32> { let num = 5; if x > 1{ Box::new(move |x| x + num) } else { Box::new(move |x| x - num) } } }
至此,闭包作为函数返回值就已完美解决,若以后你再遇到报错时,一定要仔细阅读编译器的提示,很多时候,转角都能遇到爱。
闭包的生命周期
这块儿内容在进阶生命周期章节中有讲,这里就不再赘述,读者可移步此处进行回顾。
迭代器 Iterator
如果你询问一个 Rust 资深开发:写 Rust 项目最需要掌握什么?相信迭代器往往就是答案之一。无论你是编程新手亦或是高手,实际上大概率都用过迭代器,虽然自己可能并没有意识到这一点:)
迭代器允许我们迭代一个连续的集合,例如数组、动态数组 Vec、HashMap 等,在此过程中,只需关心集合中的元素如何处理,而无需关心如何开始、如何结束、按照什么样的索引去访问等问题。
For 循环与迭代器
从用途来看,迭代器跟 for 循环颇为相似,都是去遍历一个集合,但是实际上它们存在不小的差别,其中最主要的差别就是:是否通过索引来访问集合。
例如以下的 JS 代码就是一个循环:
let arr = [1, 2, 3];
for (let i = 0; i < arr.length; i++) {
console.log(arr[i]);
}
在上面代码中,我们设置索引的开始点和结束点,然后再通过索引去访问元素 arr[i],这就是典型的循环,来对比下 Rust 中的 for:
#![allow(unused)] fn main() { let arr = [1, 2, 3]; for v in arr { println!("{}",v); } }
首先,不得不说这两语法还挺像!与 JS 循环不同,Rust中没有使用索引,它把 arr 数组当成一个迭代器,直接去遍历其中的元素,从哪里开始,从哪里结束,都无需操心。因此严格来说,Rust 中的 for 循环是编译器提供的语法糖,最终还是对迭代器中的元素进行遍历。
那又有同学要发问了,在 Rust 中数组是迭代器吗?因为在之前的代码中直接对数组 arr 进行了迭代,答案是 No。那既然数组不是迭代器,为啥咱可以对它的元素进行迭代呢?
简而言之就是数组实现了 IntoIterator 特征,Rust 通过 for 语法糖,自动把实现了该特征的数组类型转换为迭代器(你也可以为自己的集合类型实现此特征),最终让我们可以直接对一个数组进行迭代,类似的还有:
#![allow(unused)] fn main() { for i in 1..10 { println!("{}", i); } }
直接对数值序列进行迭代,也是很常见的使用方式。
IntoIterator 特征拥有一个 into_iter 方法,因此我们还可以显式的把数组转换成迭代器:
#![allow(unused)] fn main() { let arr = [1, 2, 3]; for v in arr.into_iter() { println!("{}", v); } }
迭代器是函数语言的核心特性,它赋予了 Rust 远超于循环的强大表达能力,我们将在本章中一一为大家进行展现。
惰性初始化
在 Rust 中,迭代器是惰性的,意味着如果你不使用它,那么它将不会发生任何事:
#![allow(unused)] fn main() { let v1 = vec![1, 2, 3]; let v1_iter = v1.iter(); for val in v1_iter { println!("{}", val); } }
在 for 循环之前,我们只是简单的创建了一个迭代器 v1_iter,此时不会发生任何迭代行为,只有在 for 循环开始后,迭代器才会开始迭代其中的元素,最后打印出来。
这种惰性初始化的方式确保了创建迭代器不会有任何额外的性能损耗,其中的元素也不会被消耗,只有使用到该迭代器的时候,一切才开始。
next 方法
对于 for 如何遍历迭代器,还有一个问题,它如何取出迭代器中的元素?
先来看一个特征:
#![allow(unused)] fn main() { pub trait Iterator { type Item; fn next(&mut self) -> Option<Self::Item>; // 省略其余有默认实现的方法 } }
呦,该特征竟然和迭代器 iterator 同名,难不成。。。没错,它们就是有一腿。迭代器之所以成为迭代器,就是因为实现了 Iterator 特征,要实现该特征,最主要的就是实现其中的 next 方法,该方法控制如何从集合中取值,最终返回值的类型是关联类型 Item。
因此,之前问题的答案已经很明显:for 循环通过不停调用迭代器上的 next 方法,来获取迭代器中的元素。
既然 for 可以调用 next 方法,是不是意味着我们也可以?来试试:
fn main() { let arr = [1, 2, 3]; let mut arr_iter = arr.into_iter(); assert_eq!(arr_iter.next(), Some(1)); assert_eq!(arr_iter.next(), Some(2)); assert_eq!(arr_iter.next(), Some(3)); assert_eq!(arr_iter.next(), None); }
果不其然,将 arr 转换成迭代器后,通过调用其上的 next 方法,我们获取了 arr 中的元素,有两点需要注意:
next方法返回的是Option类型,当有值时返回Some(i32),无值时返回None- 遍历是按照迭代器中元素的排列顺序依次进行的,因此我们严格按照数组中元素的顺序取出了
Some(1),Some(2),Some(3) - 手动迭代必须将迭代器声明为
mut可变,因为调用next会改变迭代器其中的状态数据(当前遍历的位置等),而for循环去迭代则无需标注mut,因为它会帮我们自动完成
总之,next 方法对迭代器的遍历是消耗性的,每次消耗它一个元素,最终迭代器中将没有任何元素,只能返回 None。
例子:模拟实现 for 循环
因为 for 循环是迭代器的语法糖,因此我们完全可以通过迭代器来模拟实现它:
#![allow(unused)] fn main() { let values = vec![1, 2, 3]; { let result = match IntoIterator::into_iter(values) { mut iter => loop { match iter.next() { Some(x) => { println!("{}", x); }, None => break, } }, }; result } }
IntoIterator::into_iter 是使用完全限定的方式去调用 into_iter 方法,这种调用方式跟 values.into_iter() 是等价的。
同时我们使用了 loop 循环配合 next 方法来遍历迭代器中的元素,当迭代器返回 None 时,跳出循环。
IntoIterator 特征
其实有一个细节,由于 Vec 动态数组实现了 IntoIterator 特征,因此可以通过 into_iter 将其转换为迭代器,那如果本身就是一个迭代器,该怎么办?实际上,迭代器自身也实现了 IntoIterator,标准库早就帮我们考虑好了:
#![allow(unused)] fn main() { impl<I: Iterator> IntoIterator for I { type Item = I::Item; type IntoIter = I; #[inline] fn into_iter(self) -> I { self } } }
最终你完全可以写出这样的奇怪代码:
fn main() { let values = vec![1, 2, 3]; for v in values.into_iter().into_iter().into_iter() { println!("{}",v) } }
into_iter, iter, iter_mut
在之前的代码中,我们统一使用了 into_iter 的方式将数组转化为迭代器,除此之外,还有 iter 和 iter_mut,聪明的读者应该大概能猜到这三者的区别:
into_iter会夺走所有权iter是借用iter_mut是可变借用
其实如果以后见多识广了,你会发现这种问题一眼就能看穿,into_ 之类的,都是拿走所有权,_mut 之类的都是可变借用,剩下的就是不可变借用。
使用一段代码来解释下:
fn main() { let values = vec![1, 2, 3]; for v in values.into_iter() { println!("{}", v) } // 下面的代码将报错,因为 values 的所有权在上面 `for` 循环中已经被转移走 // println!("{:?}",values); let values = vec![1, 2, 3]; let _values_iter = values.iter(); // 不会报错,因为 values_iter 只是借用了 values 中的元素 println!("{:?}", values); let mut values = vec![1, 2, 3]; // 对 values 中的元素进行可变借用 let mut values_iter_mut = values.iter_mut(); // 取出第一个元素,并修改为0 if let Some(v) = values_iter_mut.next() { *v = 0; } // 输出[0, 2, 3] println!("{:?}", values); }
具体解释在代码注释中,就不再赘述,不过有两点需要注意的是:
.iter()方法实现的迭代器,调用next方法返回的类型是Some(&T).iter_mut()方法实现的迭代器,调用next方法返回的类型是Some(&mut T),因此在if let Some(v) = values_iter_mut.next()中,v的类型是&mut i32,最终我们可以通过*v = 0的方式修改其值
Iterator 和 IntoIterator 的区别
这两个其实还蛮容易搞混的,但我们只需要记住,Iterator 就是迭代器特征,只有实现了它才能称为迭代器,才能调用 next。
而 IntoIterator 强调的是某一个类型如果实现了该特征,它可以通过 into_iter,iter 等方法变成一个迭代器。
消费者与适配器
消费者是迭代器上的方法,它会消费掉迭代器中的元素,然后返回其类型的值,这些消费者都有一个共同的特点:在它们的定义中,都依赖 next 方法来消费元素,因此这也是为什么迭代器要实现 Iterator 特征,而该特征必须要实现 next 方法的原因。
消费者适配器
只要迭代器上的某个方法 A 在其内部调用了 next 方法,那么 A 就被称为消费性适配器:因为 next 方法会消耗掉迭代器上的元素,所以方法 A 的调用也会消耗掉迭代器上的元素。
其中一个例子是 sum 方法,它会拿走迭代器的所有权,然后通过不断调用 next 方法对里面的元素进行求和:
fn main() { let v1 = vec![1, 2, 3]; let v1_iter = v1.iter(); let total: i32 = v1_iter.sum(); assert_eq!(total, 6); // v1_iter 是借用了 v1,因此 v1 可以照常使用 println!("{:?}",v1); // 以下代码会报错,因为 `sum` 拿到了迭代器 `v1_iter` 的所有权 // println!("{:?}",v1_iter); }
如代码注释中所说明的:在使用 sum 方法后,我们将无法再使用 v1_iter,因为 sum 拿走了该迭代器的所有权:
#![allow(unused)] fn main() { fn sum<S>(self) -> S where Self: Sized, S: Sum<Self::Item>, { Sum::sum(self) } }
从 sum 源码中也可以清晰看出,self 类型的方法参数拿走了所有权。
迭代器适配器
既然消费者适配器是消费掉迭代器,然后返回一个值。那么迭代器适配器,顾名思义,会返回一个新的迭代器,这是实现链式方法调用的关键:v.iter().map().filter()...。
与消费者适配器不同,迭代器适配器是惰性的,意味着你需要一个消费者适配器来收尾,最终将迭代器转换成一个具体的值:
#![allow(unused)] fn main() { let v1: Vec<i32> = vec![1, 2, 3]; v1.iter().map(|x| x + 1); }
运行后输出:
warning: unused `Map` that must be used
--> src/main.rs:4:5
|
4 | v1.iter().map(|x| x + 1);
| ^^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: `#[warn(unused_must_use)]` on by default
= note: iterators are lazy and do nothing unless consumed // 迭代器 map 是惰性的,这里不产生任何效果
如上述中文注释所说,这里的 map 方法是一个迭代者适配器,它是惰性的,不产生任何行为,因此我们还需要一个消费者适配器进行收尾:
#![allow(unused)] fn main() { let v1: Vec<i32> = vec![1, 2, 3]; let v2: Vec<_> = v1.iter().map(|x| x + 1).collect(); assert_eq!(v2, vec![2, 3, 4]); }
collect
上面代码中,使用了 collect 方法,该方法就是一个消费者适配器,使用它可以将一个迭代器中的元素收集到指定类型中,这里我们为 v2 标注了 Vec<_> 类型,就是为了告诉 collect:请把迭代器中的元素消费掉,然后把值收集成 Vec<_> 类型,至于为何使用 _,因为编译器会帮我们自动推导。
为何 collect 在消费时要指定类型?是因为该方法其实很强大,可以收集成多种不同的集合类型,Vec<T> 仅仅是其中之一,因此我们必须显式的告诉编译器我们想要收集成的集合类型。
还有一点值得注意,map 会对迭代器中的每一个值进行一系列操作,然后把该值转换成另外一个新值,该操作是通过闭包 |x| x + 1 来完成:最终迭代器中的每个值都增加了 1,从 [1, 2, 3] 变为 [2, 3, 4]。
再来看看如何使用 collect 收集成 HashMap 集合:
use std::collections::HashMap; fn main() { let names = ["sunface", "sunfei"]; let ages = [18, 18]; let folks: HashMap<_, _> = names.into_iter().zip(ages.into_iter()).collect(); println!("{:?}",folks); }
zip 是一个迭代器适配器,它的作用就是将两个迭代器的内容压缩到一起,形成 Iterator<Item=(ValueFromA, ValueFromB)> 这样的新的迭代器,在此处就是形如 [(name1, age1), (name2, age2)] 的迭代器。
然后再通过 collect 将新迭代器中(K, V) 形式的值收集成 HashMap<K, V>,同样的,这里必须显式声明类型,然后 HashMap 内部的 KV 类型可以交给编译器去推导,最终编译器会推导出 HashMap<&str, i32>,完全正确!
闭包作为适配器参数
之前的 map 方法中,我们使用闭包来作为迭代器适配器的参数,它最大的好处不仅在于可以就地实现迭代器中元素的处理,还在于可以捕获环境值:
#![allow(unused)] fn main() { struct Shoe { size: u32, style: String, } fn shoes_in_size(shoes: Vec<Shoe>, shoe_size: u32) -> Vec<Shoe> { shoes.into_iter().filter(|s| s.size == shoe_size).collect() } }
filter 是迭代器适配器,用于对迭代器中的每个值进行过滤。 它使用闭包作为参数,该闭包的参数 s 是来自迭代器中的值,然后使用 s 跟外部环境中的 shoe_size 进行比较,若相等,则在迭代器中保留 s 值,若不相等,则从迭代器中剔除 s 值,最终通过 collect 收集为 Vec<Shoe> 类型。
实现 Iterator 特征
之前的内容我们一直基于数组来创建迭代器,实际上,不仅仅是数组,基于其它集合类型一样可以创建迭代器,例如 HashMap。 你也可以创建自己的迭代器 —— 只要为自定义类型实现 Iterator 特征即可。
首先,创建一个计数器:
#![allow(unused)] fn main() { struct Counter { count: u32, } impl Counter { fn new() -> Counter { Counter { count: 0 } } } }
我们为计数器 Counter 实现了一个关联函数 new,用于创建新的计数器实例。下面我们继续为计数器实现 Iterator 特征:
#![allow(unused)] fn main() { impl Iterator for Counter { type Item = u32; fn next(&mut self) -> Option<Self::Item> { if self.count < 5 { self.count += 1; Some(self.count) } else { None } } } }
首先,将该特征的关联类型设置为 u32,由于我们的计数器保存的 count 字段就是 u32 类型, 因此在 next 方法中,最后返回的是实际上是 Option<u32> 类型。
每次调用 next 方法,都会让计数器的值加一,然后返回最新的计数值,一旦计数大于 5,就返回 None。
最后,使用我们新建的 Counter 进行迭代:
#![allow(unused)] fn main() { let mut counter = Counter::new(); assert_eq!(counter.next(), Some(1)); assert_eq!(counter.next(), Some(2)); assert_eq!(counter.next(), Some(3)); assert_eq!(counter.next(), Some(4)); assert_eq!(counter.next(), Some(5)); assert_eq!(counter.next(), None); }
实现 Iterator 特征的其它方法
可以看出,实现自己的迭代器非常简单,但是 Iterator 特征中,不仅仅是只有 next 一个方法,那为什么我们只需要实现它呢?因为其它方法都具有默认实现,所以无需像 next 这样手动去实现,而且这些默认实现的方法其实都是基于 next 方法实现的。
下面的代码演示了部分方法的使用:
#![allow(unused)] fn main() { let sum: u32 = Counter::new() .zip(Counter::new().skip(1)) .map(|(a, b)| a * b) .filter(|x| x % 3 == 0) .sum(); assert_eq!(18, sum); }
其中 zip,map,filter 是迭代器适配器:
zip把两个迭代器合并成一个迭代器,新迭代器中,每个元素都是一个元组,由之前两个迭代器的元素组成。例如将形如[1, 2, 3, 4, 5]和[2, 3, 4, 5]的迭代器合并后,新的迭代器形如[(1, 2),(2, 3),(3, 4),(4, 5)]map是将迭代器中的值经过映射后,转换成新的值[2, 6, 12, 20]filter对迭代器中的元素进行过滤,若闭包返回true则保留元素[6, 12],反之剔除
而 sum 是消费者适配器,对迭代器中的所有元素求和,最终返回一个 u32 值 18。
enumerate
在之前的流程控制章节,针对 for 循环,我们提供了一种方法可以获取迭代时的索引:
#![allow(unused)] fn main() { let v = vec![1u64, 2, 3, 4, 5, 6]; for (i,v) in v.iter().enumerate() { println!("第{}个值是{}",i,v) } }
相信当时,很多读者还是很迷茫的,不知道为什么要这么复杂才能获取到索引,学习本章节后,相信你有了全新的理解,首先 v.iter() 创建迭代器,其次
调用 Iterator 特征上的方法 enumerate,该方法产生一个新的迭代器,其中每个元素均是元组 (索引,值)。
因为 enumerate 是迭代器适配器,因此我们可以对它返回的迭代器调用其它 Iterator 特征方法:
#![allow(unused)] fn main() { let v = vec![1u64, 2, 3, 4, 5, 6]; let val = v.iter() .enumerate() // 每两个元素剔除一个 // [1, 3, 5] .filter(|&(idx, _)| idx % 2 == 0) .map(|(idx, val)| val) // 累加 1+3+5 = 9 .fold(0u64, |sum, acm| sum + acm); println!("{}", val); }
迭代器的性能
前面提到,要完成集合遍历,既可以使用 for 循环也可以使用迭代器,那么二者之间该怎么选择呢,性能有多大差距呢?
理论分析不会有结果,直接测试最为靠谱:
#![allow(unused)] #![feature(test)] fn main() { extern crate rand; extern crate test; fn sum_for(x: &[f64]) -> f64 { let mut result: f64 = 0.0; for i in 0..x.len() { result += x[i]; } result } fn sum_iter(x: &[f64]) -> f64 { x.iter().sum::<f64>() } #[cfg(test)] mod bench { use test::Bencher; use rand::{Rng,thread_rng}; use super::*; const LEN: usize = 1024*1024; fn rand_array(cnt: u32) -> Vec<f64> { let mut rng = thread_rng(); (0..cnt).map(|_| rng.gen::<f64>()).collect() } #[bench] fn bench_for(b: &mut Bencher) { let samples = rand_array(LEN as u32); b.iter(|| { sum_for(&samples) }) } #[bench] fn bench_iter(b: &mut Bencher) { let samples = rand_array(LEN as u32); b.iter(|| { sum_iter(&samples) }) } } }
上面的代码对比了 for 循环和迭代器 iterator 完成同样的求和任务的性能对比,可以看到迭代器还要更快一点。
test bench::bench_for ... bench: 998,331 ns/iter (+/- 36,250)
test bench::bench_iter ... bench: 983,858 ns/iter (+/- 44,673)
迭代器是 Rust 的 零成本抽象(zero-cost abstractions)之一,意味着抽象并不会引入运行时开销,这与 Bjarne Stroustrup(C++ 的设计和实现者)在 Foundations of C++(2012) 中所定义的 零开销(zero-overhead)如出一辙:
In general, C++ implementations obey the zero-overhead principle: What you don’t use, you don’t pay for. And further: What you do use, you couldn’t hand code any better.
一般来说,C++的实现遵循零开销原则:没有使用时,你不必为其买单。 更进一步说,需要使用时,你也无法写出更优的代码了。 (翻译一下:用就完事了)
总之,迭代器是 Rust 受函数式语言启发而提供的高级语言特性,可以写出更加简洁、逻辑清晰的代码。编译器还可以通过循环展开(Unrolling)、向量化、消除边界检查等优化手段,使得迭代器和 for 循环都有极为高效的执行效率。
所以请放心大胆的使用迭代器,在获得更高的表达力的同时,也不会导致运行时的损失,何乐而不为呢!
学习其它方法
迭代器用的好不好,就在于你是否掌握了它的常用方法,且能活学活用,因此多多看看标准库是有好处的,只有知道有什么方法,在需要的时候你才能知道该用什么,就和算法学习一样。
同时,本书在后续章节还提供了对迭代器常用方法的深入讲解,方便大家学习和查阅。
深入类型
Rust 是强类型语言,同时也是强安全语言,这些特性导致了 Rust 的类型注定比一般语言要更深入也更困难。
本章将深入讲解一些进阶的 Rust 类型以及类型转换,希望大家喜欢。
深入 Rust 类型
弱弱地、不负责任地说,Rust 的学习难度之恶名,可能有一半来源于 Rust 的类型系统,而其中一半的一半则来自于本章节的内容。在本章,我们将重点学习如何创建自定义类型,以及了解何为动态大小的类型。
newtype
何为 newtype?简单来说,就是使用元组结构体的方式将已有的类型包裹起来:struct Meters(u32);,那么此处 Meters 就是一个 newtype。
为何需要 newtype?Rust 这多如繁星的 Old 类型满足不了我们吗?这是因为:
- 自定义类型可以让我们给出更有意义和可读性的类型名,例如与其使用
u32作为距离的单位类型,我们可以使用Meters,它的可读性要好得多 - 对于某些场景,只有
newtype可以很好地解决 - 隐藏内部类型的细节
一箩筐的理由~~ 让我们先从第二点讲起。
为外部类型实现外部特征
在之前的章节中,我们有讲过,如果在外部类型上实现外部特征必须使用 newtype 的方式,否则你就得遵循孤儿规则:要为类型 A 实现特征 T,那么 A 或者 T 必须至少有一个在当前的作用范围内。
例如,如果想使用 println!("{}", v) 的方式去格式化输出一个动态数组 Vec,以期给用户提供更加清晰可读的内容,那么就需要为 Vec 实现 Display 特征,但是这里有一个问题: Vec 类型定义在标准库中,Display 亦然,这时就可以祭出大杀器 newtype 来解决:
use std::fmt; struct Wrapper(Vec<String>); impl fmt::Display for Wrapper { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "[{}]", self.0.join(", ")) } } fn main() { let w = Wrapper(vec![String::from("hello"), String::from("world")]); println!("w = {}", w); }
如上所示,使用元组结构体语法 struct Wrapper(Vec<String>) 创建了一个 newtype Wrapper,然后为它实现 Display 特征,最终实现了对 Vec 动态数组的格式化输出。
更好的可读性及类型异化
首先,更好的可读性不等于更少的代码(如果你学过 Scala,相信会深有体会),其次下面的例子只是一个示例,未必能体现出更好的可读性:
use std::ops::Add; use std::fmt; struct Meters(u32); impl fmt::Display for Meters { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "目标地点距离你{}米", self.0) } } impl Add for Meters { type Output = Self; fn add(self, other: Meters) -> Self { Self(self.0 + other.0) } } fn main() { let d = calculate_distance(Meters(10), Meters(20)); println!("{}", d); } fn calculate_distance(d1: Meters, d2: Meters) -> Meters { d1 + d2 }
上面代码创建了一个 newtype Meters,为其实现 Display 和 Add 特征,接着对两个距离进行求和计算,最终打印出该距离:
目标地点距离你30米
事实上,除了可读性外,还有一个极大的优点:如果给 calculate_distance 传一个其它的类型,例如 struct MilliMeters(u32);,该代码将无法编译。尽管 Meters 和 MilliMeters 都是对 u32 类型的简单包装,但是它们是不同的类型!
隐藏内部类型的细节
众所周知,Rust 的类型有很多自定义的方法,假如我们把某个类型传给了用户,但是又不想用户调用这些方法,就可以使用 newtype:
struct Meters(u32); fn main() { let i: u32 = 2; assert_eq!(i.pow(2), 4); let n = Meters(i); // 下面的代码将报错,因为`Meters`类型上没有`pow`方法 // assert_eq!(n.pow(2), 4); }
不过需要偷偷告诉你的是,这种方式实际上是掩耳盗铃,因为用户依然可以通过 n.0.pow(2) 的方式来调用内部类型的方法 :)
类型别名(Type Alias)
除了使用 newtype,我们还可以使用一个更传统的方式来创建新类型:类型别名
#![allow(unused)] fn main() { type Meters = u32 }
嗯,不得不说,类型别名的方式看起来比 newtype 顺眼的多,而且跟其它语言的使用方式几乎一致,但是:
类型别名并不是一个独立的全新的类型,而是某一个类型的别名,因此编译器依然会把 Meters 当 u32 来使用:
#![allow(unused)] fn main() { type Meters = u32; let x: u32 = 5; let y: Meters = 5; println!("x + y = {}", x + y); }
上面的代码将顺利编译通过,但是如果你使用 newtype 模式,该代码将无情报错,简单做个总结:
- 类型别名仅仅是别名,只是为了让可读性更好,并不是全新的类型,
newtype才是! - 类型别名无法实现为外部类型实现外部特征等功能,而
newtype可以
类型别名除了让类型可读性更好,还能减少模版代码的使用:
#![allow(unused)] fn main() { let f: Box<dyn Fn() + Send + 'static> = Box::new(|| println!("hi")); fn takes_long_type(f: Box<dyn Fn() + Send + 'static>) { // --snip-- } fn returns_long_type() -> Box<dyn Fn() + Send + 'static> { // --snip-- } }
f 是一个令人眼花缭乱的类型 Box<dyn Fn() + Send + 'static>,如果仔细看,会发现其实只有一个 Send 特征不认识,Send 是什么在这里不重要,你只需理解,f 就是一个 Box<dyn T> 类型的特征对象,实现了 Fn() 和 Send 特征,同时生命周期为 'static。
因为 f 的类型贼长,导致了后面我们在使用它时,到处都充斥这些不太优美的类型标注,好在类型别名可解君忧:
#![allow(unused)] fn main() { type Thunk = Box<dyn Fn() + Send + 'static>; let f: Thunk = Box::new(|| println!("hi")); fn takes_long_type(f: Thunk) { // --snip-- } fn returns_long_type() -> Thunk { // --snip-- } }
Bang!是不是?!立刻大幅简化了我们的使用。喝着奶茶、哼着歌、我写起代码撩起妹,何其快哉!
在标准库中,类型别名应用最广的就是简化 Result<T, E> 枚举。
例如在 std::io 库中,它定义了自己的 Error 类型:std::io::Error,那么如果要使用该 Result 就要用这样的语法:std::result::Result<T, std::io::Error>;,想象一下代码中充斥着这样的东东是一种什么感受?颤抖吧。。。
由于使用 std::io 库时,它的所有错误类型都是 std::io::Error,那么我们完全可以把该错误对用户隐藏起来,只在内部使用即可,因此就可以使用类型别名来简化实现:
#![allow(unused)] fn main() { type Result<T> = std::result::Result<T, std::io::Error>; }
Bingo,这样一来,其它库只需要使用 std::io::Result<T> 即可替代冗长的 std::result::Result<T, std::io::Error> 类型。
更香的是,由于它只是别名,因此我们可以用它来调用真实类型的所有方法,甚至包括 ? 符号!
!永不返回类型
在函数那章,曾经介绍过 ! 类型:! 用来说明一个函数永不返回任何值,当时可能体会不深,没事,在学习了更多手法后,保证你有全新的体验:
fn main() { let i = 2; let v = match i { 0..=3 => i, _ => println!("不合规定的值:{}", i) }; }
上面函数,会报出一个编译错误:
error[E0308]: `match` arms have incompatible types // match的分支类型不同
--> src/main.rs:5:13
|
3 | let v = match i {
| _____________-
4 | | 0..3 => i,
| | - this is found to be of type `{integer}` // 该分支返回整数类型
5 | | _ => println!("不合规定的值:{}", i)
| | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected integer, found `()` // 该分支返回()单元类型
6 | | };
| |_____- `match` arms have incompatible types
原因很简单: 要赋值给 v,就必须保证 match 的各个分支返回的值是同一个类型,但是上面一个分支返回数值、另一个分支返回元类型 (),自然会出错。
既然 println 不行,那再试试 panic
fn main() { let i = 2; let v = match i { 0..=3 => i, _ => panic!("不合规定的值:{}", i) }; }
神奇的事发生了,此处 panic 竟然通过了编译。难道这两个宏拥有不同的返回类型?
你猜的没错:panic 的返回值是 !,代表它决不会返回任何值,既然没有任何返回值,那自然不会存在分支类型不匹配的情况。
Sized 和不定长类型 DST
在 Rust 中类型有多种抽象的分类方式,例如本书之前章节的:基本类型、集合类型、复合类型等。再比如说,如果从编译器何时能获知类型大小的角度出发,可以分成两类:
- 定长类型( sized ),这些类型的大小在编译时是已知的
- 不定长类型( unsized ),与定长类型相反,它的大小只有到了程序运行时才能动态获知,这种类型又被称之为 DST
首先,我们来深入看看何为 DST。
动态大小类型 DST
读者大大们之前学过的几乎所有类型,都是固定大小的类型,包括集合 Vec、String 和 HashMap 等,而动态大小类型刚好与之相反:编译器无法在编译期得知该类型值的大小,只有到了程序运行时,才能动态获知。对于动态类型,我们使用 DST(dynamically sized types)或者 unsized 类型来称呼它。
上述的这些集合虽然底层数据可动态变化,感觉像是动态大小的类型。但是实际上,这些底层数据只是保存在堆上,在栈中还存有一个引用类型,该引用包含了集合的内存地址、元素数目、分配空间信息,通过这些信息,编译器对于该集合的实际大小了若指掌,最最重要的是:栈上的引用类型是固定大小的,因此它们依然是固定大小的类型。
正因为编译器无法在编译期获知类型大小,若你试图在代码中直接使用 DST 类型,将无法通过编译。
现在给你一个挑战:想出几个 DST 类型。俺厚黑地说一句,估计大部分人都想不出这样的一个类型,就连我,如果不是查询着资料在写,估计一时半会儿也想不到一个。
先来看一个最直白的:
试图创建动态大小的数组
#![allow(unused)] fn main() { fn my_function(n: usize) { let array = [123; n]; } }
以上代码就会报错(错误输出的内容并不是因为 DST,但根本原因是类似的),因为 n 在编译期无法得知,而数组类型的一个组成部分就是长度,长度变为动态的,自然类型就变成了 unsized 。
切片
切片也是一个典型的 DST 类型,具体详情参见另一篇文章: 易混淆的切片和切片引用。
str
考虑一下这个类型:str,感觉有点眼生?是的,它既不是 String 动态字符串,也不是 &str 字符串切片,而是一个 str。它是一个动态类型,同时还是 String 和 &str 的底层数据类型。 由于 str 是动态类型,因此它的大小直到运行期才知道,下面的代码会因此报错:
#![allow(unused)] fn main() { // error let s1: str = "Hello there!"; let s2: str = "How's it going?"; // ok let s3: &str = "on?" }
Rust 需要明确地知道一个特定类型的值占据了多少内存空间,同时该类型的所有值都必须使用相同大小的内存。如果 Rust 允许我们使用这种动态类型,那么这两个 str 值就需要占用同样大小的内存,这显然是不现实的: s1 占用了 12 字节,s2 占用了 15 字节,总不至于为了满足同样的内存大小,用空白字符去填补字符串吧?
所以,我们只有一条路走,那就是给它们一个固定大小的类型:&str。那么为何字符串切片 &str 就是固定大小呢?因为它的引用存储在栈上,具有固定大小(类似指针),同时它指向的数据存储在堆中,也是已知的大小,再加上 &str 引用中包含有堆上数据内存地址、长度等信息,因此最终可以得出字符串切片是固定大小类型的结论。
与 &str 类似,String 字符串也是固定大小的类型。
正是因为 &str 的引用有了底层堆数据的明确信息,它才是固定大小类型。假设如果它没有这些信息呢?那它也将变成一个动态类型。因此,将动态数据固定化的秘诀就是使用引用指向这些动态数据,然后在引用中存储相关的内存位置、长度等信息。
特征对象
#![allow(unused)] fn main() { fn foobar_1(thing: &dyn MyThing) {} // OK fn foobar_2(thing: Box<dyn MyThing>) {} // OK fn foobar_3(thing: MyThing) {} // ERROR! }
如上所示,只能通过引用或 Box 的方式来使用特征对象,直接使用将报错!
总结:只能间接使用的 DST
Rust 中常见的 DST 类型有: str、[T]、dyn Trait,它们都无法单独被使用,必须要通过引用或者 Box 来间接使用 。
我们之前已经见过,使用 Box 将一个没有固定大小的特征变成一个有固定大小的特征对象,那能否故技重施,将 str 封装成一个固定大小类型?留个悬念先,我们来看看 Sized 特征。
Sized 特征
既然动态类型的问题这么大,那么在使用泛型时,Rust 如何保证我们的泛型参数是固定大小的类型呢?例如以下泛型函数:
#![allow(unused)] fn main() { fn generic<T>(t: T) { // --snip-- } }
该函数很简单,就一个泛型参数 T,那么如何保证 T 是固定大小的类型?仔细回想下,貌似在之前的课程章节中,我们也没有做过任何事情去做相关的限制,那 T 怎么就成了固定大小的类型了?奥秘在于编译器自动帮我们加上了 Sized 特征约束:
#![allow(unused)] fn main() { fn generic<T: Sized>(t: T) { // --snip-- } }
在上面,Rust 自动添加的特征约束 T: Sized,表示泛型函数只能用于一切实现了 Sized 特征的类型上,而所有在编译时就能知道其大小的类型,都会自动实现 Sized 特征,例如。。。。也没啥好例如的,你能想到的几乎所有类型都实现了 Sized 特征,除了上面那个坑坑的 str,哦,还有特征。
每一个特征都是一个可以通过名称来引用的动态大小类型。因此如果想把特征作为具体的类型来传递给函数,你必须将其转换成一个特征对象:诸如 &dyn Trait 或者 Box<dyn Trait> (还有 Rc<dyn Trait>)这些引用类型。
现在还有一个问题:假如想在泛型函数中使用动态数据类型怎么办?可以使用 ?Sized 特征(不得不说这个命名方式很 Rusty,竟然有点幽默):
#![allow(unused)] fn main() { fn generic<T: ?Sized>(t: &T) { // --snip-- } }
?Sized 特征用于表明类型 T 既有可能是固定大小的类型,也可能是动态大小的类型。还有一点要注意的是,函数参数类型从 T 变成了 &T,因为 T 可能是动态大小的,因此需要用一个固定大小的指针(引用)来包裹它。
Box<str>
在结束前,再来看看之前遗留的问题:使用 Box 可以将一个动态大小的特征变成一个具有固定大小的特征对象,能否故技重施,将 str 封装成一个固定大小类型?
先回想下,章节前面的内容介绍过该如何把一个动态大小类型转换成固定大小的类型: 使用引用指向这些动态数据,然后在引用中存储相关的内存位置、长度等信息。
好的,根据这个,我们来一起推测。首先,Box<str> 使用了一个引用来指向 str,嗯,满足了第一个条件。但是第二个条件呢?Box 中有该 str 的长度信息吗?显然是 No。那为什么特征就可以变成特征对象?其实这个还蛮复杂的,简单来说,对于特征对象,编译器无需知道它具体是什么类型,只要知道它能调用哪几个方法即可,因此编译器帮我们实现了剩下的一切。
来验证下我们的推测:
fn main() { let s1: Box<str> = Box::new("Hello there!" as str); }
报错如下:
error[E0277]: the size for values of type `str` cannot be known at compilation time
--> src/main.rs:2:24
|
2 | let s1: Box<str> = Box::new("Hello there!" as str);
| ^^^^^^^^ doesn't have a size known at compile-time
|
= help: the trait `Sized` is not implemented for `str`
= note: all function arguments must have a statically known size
提示得很清晰,不知道 str 的大小,因此无法使用这种语法进行 Box 进装,但是你可以这么做:
#![allow(unused)] fn main() { let s1: Box<str> = "Hello there!".into(); }
主动转换成 str 的方式不可行,但是可以让编译器来帮我们完成,只要告诉它我们需要的类型即可。
整数转换为枚举
在 Rust 中,从枚举到整数的转换很容易,但是反过来,就没那么容易,甚至部分实现还挺邪恶, 例如使用transmute。
一个真实场景的需求
在实际场景中,从枚举到整数的转换有时还是非常需要的,例如你有一个枚举类型,然后需要从外面传入一个整数,用于控制后续的流程走向,此时就需要用整数去匹配相应的枚举(你也可以用整数匹配整数-, -,看看会不会被喷)。
既然有了需求,剩下的就是看看该如何实现,这篇文章的水远比你想象的要深,且看八仙过海各显神通。
C 语言的实现
对于 C 语言来说,万物皆邪恶,因此我们不讨论安全,只看实现,不得不说很简洁:
#include <stdio.h>
enum atomic_number {
HYDROGEN = 1,
HELIUM = 2,
// ...
IRON = 26,
};
int main(void)
{
enum atomic_number element = 26;
if (element == IRON) {
printf("Beware of Rust!\n");
}
return 0;
}
但是在 Rust 中,以下代码:
enum MyEnum { A = 1, B, C, } fn main() { // 将枚举转换成整数,顺利通过 let x = MyEnum::C as i32; // 将整数转换为枚举,失败 match x { MyEnum::A => {} MyEnum::B => {} MyEnum::C => {} _ => {} } }
就会报错: MyEnum::A => {} mismatched types, expected i32, found enum MyEnum。
使用三方库
首先可以想到的肯定是三方库,毕竟 Rust 的生态目前已经发展的很不错,类似的需求总是有的,这里我们先使用num-traits和num-derive来试试。
在Cargo.toml中引入:
[dependencies]
num-traits = "0.2.14"
num-derive = "0.3.3"
代码如下:
use num_derive::FromPrimitive; use num_traits::FromPrimitive; #[derive(FromPrimitive)] enum MyEnum { A = 1, B, C, } fn main() { let x = 2; match FromPrimitive::from_i32(x) { Some(MyEnum::A) => println!("Got A"), Some(MyEnum::B) => println!("Got B"), Some(MyEnum::C) => println!("Got C"), None => println!("Couldn't convert {}", x), } }
除了上面的库,还可以使用一个较新的库: num_enums。
TryFrom + 宏
在 Rust 1.34 后,可以实现TryFrom特征来做转换:
#![allow(unused)] fn main() { use std::convert::TryFrom; impl TryFrom<i32> for MyEnum { type Error = (); fn try_from(v: i32) -> Result<Self, Self::Error> { match v { x if x == MyEnum::A as i32 => Ok(MyEnum::A), x if x == MyEnum::B as i32 => Ok(MyEnum::B), x if x == MyEnum::C as i32 => Ok(MyEnum::C), _ => Err(()), } } } }
以上代码定义了从i32到MyEnum的转换,接着就可以使用TryInto来实现转换:
use std::convert::TryInto; fn main() { let x = MyEnum::C as i32; match x.try_into() { Ok(MyEnum::A) => println!("a"), Ok(MyEnum::B) => println!("b"), Ok(MyEnum::C) => println!("c"), Err(_) => eprintln!("unknown number"), } }
但是上面的代码有个问题,你需要为每个枚举成员都实现一个转换分支,非常麻烦。好在可以使用宏来简化,自动根据枚举的定义来实现TryFrom特征:
#![allow(unused)] fn main() { #[macro_export] macro_rules! back_to_enum { ($(#[$meta:meta])* $vis:vis enum $name:ident { $($(#[$vmeta:meta])* $vname:ident $(= $val:expr)?,)* }) => { $(#[$meta])* $vis enum $name { $($(#[$vmeta])* $vname $(= $val)?,)* } impl std::convert::TryFrom<i32> for $name { type Error = (); fn try_from(v: i32) -> Result<Self, Self::Error> { match v { $(x if x == $name::$vname as i32 => Ok($name::$vname),)* _ => Err(()), } } } } } back_to_enum! { enum MyEnum { A = 1, B, C, } } }
邪恶之王 std::mem::transmute
这个方法原则上并不推荐,但是有其存在的意义,如果要使用,你需要清晰的知道自己为什么使用。
在之前的类型转换章节,我们提到过非常邪恶的transmute转换,其实,当你知道数值一定不会超过枚举的范围时(例如枚举成员对应 1,2,3,传入的整数也在这个范围内),就可以使用这个方法完成变形。
最好使用#[repr(..)]来控制底层类型的大小,免得本来需要 i32,结果传入 i64,最终内存无法对齐,产生奇怪的结果
#[repr(i32)] enum MyEnum { A = 1, B, C } fn main() { let x = MyEnum::C; let y = x as i32; let z: MyEnum = unsafe { std::mem::transmute(y) }; // match the enum that came from an int match z { MyEnum::A => { println!("Found A"); } MyEnum::B => { println!("Found B"); } MyEnum::C => { println!("Found C"); } } }
既然是邪恶之王,当然得有真本事,无需标准库、也无需 unstable 的 Rust 版本,我们就完成了转换!awesome!??
总结
本文列举了常用(其实差不多也是全部了,还有一个 unstable 特性没提到)的从整数转换为枚举的方式,推荐度按照出现的先后顺序递减。
但是推荐度最低,不代表它就没有出场的机会,只要使用边界清晰,一样可以大放光彩,例如最后的transmute函数.
智能指针
在各个编程语言中,指针的概念几乎都是相同的:指针是一个包含了内存地址的变量,该内存地址引用或者指向了另外的数据。
在 Rust 中,最常见的指针类型是引用,引用通过 & 符号表示。不同于其它语言,引用在 Rust 中被赋予了更深层次的含义,那就是:借用其它变量的值。引用本身很简单,除了指向某个值外并没有其它的功能,也不会造成性能上的额外损耗,因此是 Rust 中使用最多的指针类型。
而智能指针则不然,它虽然也号称指针,但是它是一个复杂的家伙:通过比引用更复杂的数据结构,包含比引用更多的信息,例如元数据,当前长度,最大可用长度等。总之,Rust 的智能指针并不是独创,在 C++ 或者其他语言中也存在相似的概念。
Rust 标准库中定义的那些智能指针,虽重但强,可以提供比引用更多的功能特性,例如本章将讨论的引用计数智能指针。该智能指针允许你同时拥有同一个数据的多个所有权,它会跟踪每一个所有者并进行计数,当所有的所有者都归还后,该智能指针及指向的数据将自动被清理释放。
引用和智能指针的另一个不同在于前者仅仅是借用了数据,而后者往往可以拥有它们指向的数据,然后再为其它人提供服务。
在之前的章节中,实际上我们已经见识过多种智能指针,例如动态字符串 String 和动态数组 Vec,它们的数据结构中不仅仅包含了指向底层数据的指针,还包含了当前长度、最大长度等信息,其中 String 智能指针还提供了一种担保信息:所有的数据都是合法的 UTF-8 格式。
智能指针往往是基于结构体实现,它与我们自定义的结构体最大的区别在于它实现了 Deref 和 Drop 特征:
Deref可以让智能指针像引用那样工作,这样你就可以写出同时支持智能指针和引用的代码,例如*TDrop允许你指定智能指针超出作用域后自动执行的代码,例如做一些数据清除等收尾工作
智能指针在 Rust 中很常见,我们在本章不会全部讲解,而是挑选几个最常用、最有代表性的进行讲解:
Box<T>,可以将值分配到堆上Rc<T>,引用计数类型,允许多所有权存在Ref<T>和RefMut<T>,允许将借用规则检查从编译期移动到运行期进行
Box<T> 堆对象分配
关于作者帅不帅,估计争议还挺多的,但是如果说 Box<T> 是不是 Rust 中最常见的智能指针,那估计没有任何争议。因为 Box<T> 允许你将一个值分配到堆上,然后在栈上保留一个智能指针指向堆上的数据。
之前我们在所有权章节简单讲过堆栈的概念,这里再补充一些。
Rust 中的堆栈
高级语言 Python/Java 等往往会弱化堆栈的概念,但是要用好 C/C++/Rust,就必须对堆栈有深入的了解,原因是两者的内存管理方式不同:前者有 GC 垃圾回收机制,因此无需你去关心内存的细节。
栈内存从高位地址向下增长,且栈内存是连续分配的,一般来说操作系统对栈内存的大小都有限制,因此 C 语言中无法创建任意长度的数组。在 Rust 中,main 线程的栈大小是 8MB,普通线程是 2MB,在函数调用时会在其中创建一个临时栈空间,调用结束后 Rust 会让这个栈空间里的对象自动进入 Drop 流程,最后栈顶指针自动移动到上一个调用栈顶,无需程序员手动干预,因而栈内存申请和释放是非常高效的。
与栈相反,堆上内存则是从低位地址向上增长,堆内存通常只受物理内存限制,而且通常是不连续的,因此从性能的角度看,栈往往比堆更高。
相比其它语言,Rust 堆上对象还有一个特殊之处,它们都拥有一个所有者,因此受所有权规则的限制:当赋值时,发生的是所有权的转移(只需浅拷贝栈上的引用或智能指针即可),例如以下代码:
fn main() { let b = foo("world"); println!("{}", b); } fn foo(x: &str) -> String { let a = "Hello, ".to_string() + x; a }
在 foo 函数中,a 是 String 类型,它其实是一个智能指针结构体,该智能指针存储在函数栈中,指向堆上的字符串数据。当被从 foo 函数转移给 main 中的 b 变量时,栈上的智能指针被复制一份赋予给 b,而底层数据无需发生改变,这样就完成了所有权从 foo 函数内部到 b 的转移。
堆栈的性能
很多人可能会觉得栈的性能肯定比堆高,其实未必。 由于我们在后面的性能专题会专门讲解堆栈的性能问题,因此这里就大概给出结论:
- 小型数据,在栈上的分配性能和读取性能都要比堆上高
- 中型数据,栈上分配性能高,但是读取性能和堆上并无区别,因为无法利用寄存器或 CPU 高速缓存,最终还是要经过一次内存寻址
- 大型数据,只建议在堆上分配和使用
总之,栈的分配速度肯定比堆上快,但是读取速度往往取决于你的数据能不能放入寄存器或 CPU 高速缓存。 因此不要仅仅因为堆上性能不如栈这个印象,就总是优先选择栈,导致代码更复杂的实现。
Box 的使用场景
由于 Box 是简单的封装,除了将值存储在堆上外,并没有其它性能上的损耗。而性能和功能往往是鱼和熊掌,因此 Box 相比其它智能指针,功能较为单一,可以在以下场景中使用它:
- 特意的将数据分配在堆上
- 数据较大时,又不想在转移所有权时进行数据拷贝
- 类型的大小在编译期无法确定,但是我们又需要固定大小的类型时
- 特征对象,用于说明对象实现了一个特征,而不是某个特定的类型
以上场景,我们在本章将一一讲解,后面车速较快,请系好安全带。
使用 Box<T> 将数据存储在堆上
如果一个变量拥有一个数值 let a = 3,那变量 a 必然是存储在栈上的,那如果我们想要 a 的值存储在堆上就需要使用 Box<T>:
fn main() { let a = Box::new(3); println!("a = {}", a); // a = 3 // 下面一行代码将报错 // let b = a + 1; // cannot add `{integer}` to `Box<{integer}>` }
这样就可以创建一个智能指针指向了存储在堆上的 3,并且 a 持有了该指针。在本章的引言中,我们提到了智能指针往往都实现了 Deref 和 Drop 特征,因此:
println!可以正常打印出a的值,是因为它隐式地调用了Deref对智能指针a进行了解引用- 最后一行代码
let b = a + 1报错,是因为在表达式中,我们无法自动隐式地执行Deref解引用操作,你需要使用*操作符let b = *a + 1,来显式的进行解引用 a持有的智能指针将在作用域结束(main函数结束)时,被释放掉,这是因为Box<T>实现了Drop特征
以上的例子在实际代码中其实很少会存在,因为将一个简单的值分配到堆上并没有太大的意义。将其分配在栈上,由于寄存器、CPU 缓存的原因,它的性能将更好,而且代码可读性也更好。
避免栈上数据的拷贝
当栈上数据转移所有权时,实际上是把数据拷贝了一份,最终新旧变量各自拥有不同的数据,因此所有权并未转移。
而堆上则不然,底层数据并不会被拷贝,转移所有权仅仅是复制一份栈中的指针,再将新的指针赋予新的变量,然后让拥有旧指针的变量失效,最终完成了所有权的转移:
fn main() { // 在栈上创建一个长度为1000的数组 let arr = [0;1000]; // 将arr所有权转移arr1,由于 `arr` 分配在栈上,因此这里实际上是直接重新深拷贝了一份数据 let arr1 = arr; // arr 和 arr1 都拥有各自的栈上数组,因此不会报错 println!("{:?}", arr.len()); println!("{:?}", arr1.len()); // 在堆上创建一个长度为1000的数组,然后使用一个智能指针指向它 let arr = Box::new([0;1000]); // 将堆上数组的所有权转移给 arr1,由于数据在堆上,因此仅仅拷贝了智能指针的结构体,底层数据并没有被拷贝 // 所有权顺利转移给 arr1,arr 不再拥有所有权 let arr1 = arr; println!("{:?}", arr1.len()); // 由于 arr 不再拥有底层数组的所有权,因此下面代码将报错 // println!("{:?}", arr.len()); }
从以上代码,可以清晰看出大块的数据为何应该放入堆中,此时 Box 就成为了我们最好的帮手。
将动态大小类型变为 Sized 固定大小类型
Rust 需要在编译时知道类型占用多少空间,如果一种类型在编译时无法知道具体的大小,那么被称为动态大小类型 DST。
其中一种无法在编译时知道大小的类型是递归类型:在类型定义中又使用到了自身,或者说该类型的值的一部分可以是相同类型的其它值,这种值的嵌套理论上可以无限进行下去,所以 Rust 不知道递归类型需要多少空间:
#![allow(unused)] fn main() { enum List { Cons(i32, List), Nil, } }
以上就是函数式语言中常见的 Cons List,它的每个节点包含一个 i32 值,还包含了一个新的 List,因此这种嵌套可以无限进行下去,Rust 认为该类型是一个 DST 类型,并给予报错:
error[E0072]: recursive type `List` has infinite size //递归类型 `List` 拥有无限长的大小
--> src/main.rs:3:1
|
3 | enum List {
| ^^^^^^^^^ recursive type has infinite size
4 | Cons(i32, List),
| ---- recursive without indirection
此时若想解决这个问题,就可以使用我们的 Box<T>:
#![allow(unused)] fn main() { enum List { Cons(i32, Box<List>), Nil, } }
只需要将 List 存储到堆上,然后使用一个智能指针指向它,即可完成从 DST 到 Sized 类型(固定大小类型)的华丽转变。
特征对象
在 Rust 中,想实现不同类型组成的数组只有两个办法:枚举和特征对象,前者限制较多,因此后者往往是最常用的解决办法。
trait Draw { fn draw(&self); } struct Button { id: u32, } impl Draw for Button { fn draw(&self) { println!("这是屏幕上第{}号按钮", self.id) } } struct Select { id: u32, } impl Draw for Select { fn draw(&self) { println!("这个选择框贼难用{}", self.id) } } fn main() { let elems: Vec<Box<dyn Draw>> = vec![Box::new(Button { id: 1 }), Box::new(Select { id: 2 })]; for e in elems { e.draw() } }
以上代码将不同类型的 Button 和 Select 包装成 Draw 特征的特征对象,放入一个数组中,Box<dyn Draw> 就是特征对象。
其实,特征也是 DST 类型,而特征对象在做的就是将 DST 类型转换为固定大小类型。
Box 内存布局
先来看看 Vec<i32> 的内存布局:
#![allow(unused)] fn main() { (stack) (heap) ┌──────┐ ┌───┐ │ vec1 │──→│ 1 │ └──────┘ ├───┤ │ 2 │ ├───┤ │ 3 │ ├───┤ │ 4 │ └───┘ }
之前提到过 Vec 和 String 都是智能指针,从上图可以看出,该智能指针存储在栈中,然后指向堆上的数组数据。
那如果数组中每个元素都是一个 Box 对象呢?来看看 Vec<Box<i32>> 的内存布局:
#![allow(unused)] fn main() { (heap) (stack) (heap) ┌───┐ ┌──────┐ ┌───┐ ┌─→│ 1 │ │ vec2 │──→│B1 │─┘ └───┘ └──────┘ ├───┤ ┌───┐ │B2 │───→│ 2 │ ├───┤ └───┘ │B3 │─┐ ┌───┐ ├───┤ └─→│ 3 │ │B4 │─┐ └───┘ └───┘ │ ┌───┐ └─→│ 4 │ └───┘ }
上面的 B1 代表被 Box 分配到堆上的值 1。
可以看出智能指针 vec2 依然是存储在栈上,然后指针指向一个堆上的数组,该数组中每个元素都是一个 Box 智能指针,最终 Box 智能指针又指向了存储在堆上的实际值。
因此当我们从数组中取出某个元素时,取到的是对应的智能指针 Box,需要对该智能指针进行解引用,才能取出最终的值:
fn main() { let arr = vec![Box::new(1), Box::new(2)]; let (first, second) = (&arr[0], &arr[1]); let sum = **first + **second; }
以上代码有几个值得注意的点:
- 使用
&借用数组中的元素,否则会报所有权错误 - 表达式不能隐式的解引用,因此必须使用
**做两次解引用,第一次将&Box<i32>类型转成Box<i32>,第二次将Box<i32>转成i32
Box::leak
Box 中还提供了一个非常有用的关联函数:Box::leak,它可以消费掉 Box 并且强制目标值从内存中泄漏,读者可能会觉得,这有啥用啊?
其实还真有点用,例如,你可以把一个 String 类型,变成一个 'static 生命周期的 &str 类型:
fn main() { let s = gen_static_str(); println!("{}", s); } fn gen_static_str() -> &'static str{ let mut s = String::new(); s.push_str("hello, world"); Box::leak(s.into_boxed_str()) }
在之前的代码中,如果 String 创建于函数中,那么返回它的唯一方法就是转移所有权给调用者 fn move_str() -> String,而通过 Box::leak 我们不仅返回了一个 &str 字符串切片,它还是 'static 生命周期的!
要知道真正具有 'static 生命周期的往往都是编译期就创建的值,例如 let v = "hello, world",这里 v 是直接打包到二进制可执行文件中的,因此该字符串具有 'static 生命周期,再比如 const 常量。
又有读者要问了,我还可以手动为变量标注 'static 啊。其实你标注的 'static 只是用来忽悠编译器的,但是超出作用域,一样被释放回收。而使用 Box::leak 就可以将一个运行期的值转为 'static。
使用场景
光看上面的描述,大家可能还是云里雾里、一头雾水。
那么我说一个简单的场景,你需要一个在运行期初始化的值,但是可以全局有效,也就是和整个程序活得一样久,那么就可以使用 Box::leak,例如有一个存储配置的结构体实例,它是在运行期动态插入内容,那么就可以将其转为全局有效,虽然 Rc/Arc 也可以实现此功能,但是 Box::leak 是性能最高的。
总结
Box 背后是调用 jemalloc 来做内存管理,所以堆上的空间无需我们的手动管理。与此类似,带 GC 的语言中的对象也是借助于 Box 概念来实现的,一切皆对象 = 一切皆 Box, 只不过我们无需自己去 Box 罢了。
其实很多时候,编译器的鞭笞可以助我们更快的成长,例如所有权规则里的借用、move、生命周期就是编译器在教我们做人,哦不是,是教我们深刻理解堆栈、内存布局、作用域等等你在其它 GC 语言无需去关注的东西。刚开始是很痛苦,但是一旦熟悉了这套规则,写代码的效率和代码本身的质量将飞速上升,直到你可以用 Java 开发的效率写出 Java 代码不可企及的性能和安全性,最终 Rust 语言所谓的开发效率低、心智负担高,对你来说终究不是个事。
因此,不要怪 Rust,它只是在帮我们成为那个更好的程序员,而这些苦难终究成为我们走向优秀的垫脚石。
Deref 解引用
在开始之前,我们先来看一段代码:
#![allow(unused)] fn main() { #[derive(Debug)] struct Person { name: String, age: u8 } impl Person { fn new(name: String, age: u8) -> Self { Person { name, age} } fn display(self: &mut Person, age: u8) { let Person{name, age} = &self; } } }
以上代码有一个很奇怪的地方:在 display 方法中,self 是 &mut Person 的类型,接着我们对其取了一次引用 &self,此时 &self 的类型是 &&mut Person,然后我们又将其和 Person 类型进行匹配,取出其中的值。
那么问题来了,Rust 不是号称安全的语言吗?为何允许将 &&mut Person 跟 Person 进行匹配呢?答案就在本章节中,等大家学完后,再回头自己来解决这个问题 :) 下面正式开始咱们的新章节学习。
何为智能指针?能不让你写出 ****s 形式的解引用,我认为就是智能: ),智能指针的名称来源,主要就在于它实现了 Deref 和 Drop 特征,这两个特征可以智能地帮助我们节省使用上的负担:
Deref可以让智能指针像引用那样工作,这样你就可以写出同时支持智能指针和引用的代码,例如*TDrop允许你指定智能指针超出作用域后自动执行的代码,例如做一些数据清除等收尾工作
先来看看 Deref 特征是如何工作的。
通过 * 获取引用背后的值
在正式讲解 Deref 之前,我们先来看下常规引用的解引用。
常规引用是一个指针类型,包含了目标数据存储的内存地址。对常规引用使用 * 操作符,就可以通过解引用的方式获取到内存地址对应的数据值:
fn main() { let x = 5; let y = &x; assert_eq!(5, x); assert_eq!(5, *y); }
这里 y 就是一个常规引用,包含了值 5 所在的内存地址,然后通过解引用 *y,我们获取到了值 5。如果你试图执行 assert_eq!(5, y);,代码就会无情报错,因为你无法将一个引用与一个数值做比较:
error[E0277]: can't compare `{integer}` with `&{integer}` //无法将{integer} 与&{integer}进行比较
--> src/main.rs:6:5
|
6 | assert_eq!(5, y);
| ^^^^^^^^^^^^^^^^^ no implementation for `{integer} == &{integer}`
|
= help: the trait `PartialEq<&{integer}>` is not implemented for `{integer}`
// 你需要为{integer}实现用于比较的特征PartialEq<&{integer}>
智能指针解引用
上面所说的解引用方式和其它大多数语言并无区别,但是 Rust 中将解引用提升到了一个新高度。考虑一下智能指针,它是一个结构体类型,如果你直接对它进行 *myStruct,显然编译器不知道该如何办,因此我们可以为智能指针结构体实现 Deref 特征。
实现 Deref 后的智能指针结构体,就可以像普通引用一样,通过 * 进行解引用,例如 Box<T> 智能指针:
fn main() { let x = Box::new(1); let sum = *x + 1; }
智能指针 x 被 * 解引用为 i32 类型的值 1,然后再进行求和。
定义自己的智能指针
现在,让我们一起来实现一个智能指针,功能上类似 Box<T>。由于 Box<T> 本身很简单,并没有包含类如长度、最大长度等信息,因此用一个元组结构体即可。
#![allow(unused)] fn main() { struct MyBox<T>(T); impl<T> MyBox<T> { fn new(x: T) -> MyBox<T> { MyBox(x) } } }
跟 Box<T> 一样,我们的智能指针也持有一个 T 类型的值,然后使用关联函数 MyBox::new 来创建智能指针。由于还未实现 Deref 特征,此时使用 * 肯定会报错:
fn main() { let y = MyBox::new(5); assert_eq!(5, *y); }
运行后,报错如下:
error[E0614]: type `MyBox<{integer}>` cannot be dereferenced
--> src/main.rs:12:19
|
12 | assert_eq!(5, *y);
| ^^
为智能指针实现 Deref 特征
现在来为 MyBox 实现 Deref 特征,以支持 * 解引用操作符:
#![allow(unused)] fn main() { use std::ops::Deref; impl<T> Deref for MyBox<T> { type Target = T; fn deref(&self) -> &Self::Target { &self.0 } } }
很简单,当解引用 MyBox 智能指针时,返回元组结构体中的元素 &self.0,有几点要注意的:
- 在
Deref特征中声明了关联类型Target,在之前章节中介绍过,关联类型主要是为了提升代码可读性 deref返回的是一个常规引用,可以被*进行解引用
之前报错的代码此时已能顺利编译通过。当然,标准库实现的智能指针要考虑很多边边角角情况,肯定比我们的实现要复杂。
* 背后的原理
当我们对智能指针 Box 进行解引用时,实际上 Rust 为我们调用了以下方法:
#![allow(unused)] fn main() { *(y.deref()) }
首先调用 deref 方法返回值的常规引用,然后通过 * 对常规引用进行解引用,最终获取到目标值。
至于 Rust 为何要使用这个有点啰嗦的方式实现,原因在于所有权系统的存在。如果 deref 方法直接返回一个值,而不是引用,那么该值的所有权将被转移给调用者,而我们不希望调用者仅仅只是 *T 一下,就拿走了智能指针中包含的值。
需要注意的是,* 不会无限递归替换,从 *y 到 *(y.deref()) 只会发生一次,而不会继续进行替换然后产生形如 *((y.deref()).deref()) 的怪物。
函数和方法中的隐式 Deref 转换
对于函数和方法的传参,Rust 提供了一个极其有用的隐式转换:Deref 转换。若一个类型实现了 Deref 特征,那它的引用在传给函数或方法时,会根据参数签名来决定是否进行隐式的 Deref 转换,例如:
fn main() { let s = String::from("hello world"); display(&s) } fn display(s: &str) { println!("{}",s); }
以上代码有几点值得注意:
String实现了Deref特征,可以在需要时自动被转换为&str类型&s是一个&String类型,当它被传给display函数时,自动通过Deref转换成了&str- 必须使用
&s的方式来触发Deref(仅引用类型的实参才会触发自动解引用)
连续的隐式 Deref 转换
如果你以为 Deref 仅仅这点作用,那就大错特错了。Deref 可以支持连续的隐式转换,直到找到适合的形式为止:
fn main() { let s = MyBox::new(String::from("hello world")); display(&s) } fn display(s: &str) { println!("{}",s); }
这里我们使用了之前自定义的智能指针 MyBox,并将其通过连续的隐式转换变成 &str 类型:首先 MyBox 被 Deref 成 String 类型,结果并不能满足 display 函数参数的要求,编译器发现 String 还可以继续 Deref 成 &str,最终成功的匹配了函数参数。
想象一下,假如 Rust 没有提供这种隐式转换,我们该如何调用 display 函数?
fn main() { let m = MyBox::new(String::from("Rust")); display(&(*m)[..]); }
结果不言而喻,肯定是 &s 的方式优秀得多。总之,当参与其中的类型定义了 Deref 特征时,Rust 会分析该类型并且连续使用 Deref 直到最终获得一个引用来匹配函数或者方法的参数类型,这种行为完全不会造成任何的性能损耗,因为完全是在编译期完成。
但是 Deref 并不是没有缺点,缺点就是:如果你不知道某个类型是否实现了 Deref 特征,那么在看到某段代码时,并不能在第一时间反应过来该代码发生了隐式的 Deref 转换。事实上,不仅仅是 Deref,在 Rust 中还有各种 From/Into 等等会给阅读代码带来一定负担的特征。还是那句话,一切选择都是权衡,有得必有失,得了代码的简洁性,往往就失去了可读性,Go 语言就是一个刚好相反的例子。
再来看一下在方法、赋值中自动应用 Deref 的例子:
fn main() { let s = MyBox::new(String::from("hello, world")); let s1: &str = &s; let s2: String = s.to_string(); }
对于 s1,我们通过两次 Deref 将 &str 类型的值赋给了它(赋值操作需要手动解引用);而对于 s2,我们在其上直接调用方法 to_string,实际上 MyBox 根本没有没有实现该方法,能调用 to_string,完全是因为编译器对 MyBox 应用了 Deref 的结果(方法调用会自动解引用)。
Deref 规则总结
在上面,我们零碎的介绍了不少关于 Deref 特征的知识,下面来通过较为正式的方式来对其规则进行下总结。
一个类型为 T 的对象 foo,如果 T: Deref<Target=U>,那么,相关 foo 的引用 &foo 在应用的时候会自动转换为 &U。
粗看这条规则,貌似有点类似于 AsRef,而跟 解引用 似乎风马牛不相及,实际里面有些玄妙之处。
引用归一化
Rust 编译器实际上只能对 &v 形式的引用进行解引用操作,那么问题来了,如果是一个智能指针或者 &&&&v 类型的呢? 该如何对这两个进行解引用?
答案是:Rust 会在解引用时自动把智能指针和 &&&&v 做引用归一化操作,转换成 &v 形式,最终再对 &v 进行解引用:
- 把智能指针(比如在库中定义的,Box、Rc、Arc、Cow 等)从结构体脱壳为内部的引用类型,也就是转成结构体内部的
&v - 把多重
&,例如&&&&&&&v,归一成&v
关于第二种情况,这么干巴巴的说,也许大家会迷迷糊糊的,我们来看一段标准库源码:
#![allow(unused)] fn main() { impl<T: ?Sized> Deref for &T { type Target = T; fn deref(&self) -> &T { *self } } }
在这段源码中,&T 被自动解引用为 T,也就是 &T: Deref<Target=T> 。 按照这个代码,&&&&T 会被自动解引用为 &&&T,然后再自动解引用为 &&T,以此类推, 直到最终变成 &T。
PS: 以下是 LLVM 编译后的部分中间层代码:
#![allow(unused)] fn main() { // Rust 代码 let mut _2: &i32; let _3: &&&&i32; bb0: { _2 = (*(*(*_3))) } }
几个例子
#![allow(unused)] fn main() { fn foo(s: &str) {} // 由于 String 实现了 Deref<Target=str> let owned = "Hello".to_string(); // 因此下面的函数可以正常运行: foo(&owned); }
#![allow(unused)] fn main() { use std::rc::Rc; fn foo(s: &str) {} // String 实现了 Deref<Target=str> let owned = "Hello".to_string(); // 且 Rc 智能指针可以被自动脱壳为内部的 `owned` 引用: &String ,然后 &String 再自动解引用为 &str let counted = Rc::new(owned); // 因此下面的函数可以正常运行: foo(&counted); }
#![allow(unused)] fn main() { struct Foo; impl Foo { fn foo(&self) { println!("Foo"); } } let f = &&Foo; f.foo(); (&f).foo(); (&&f).foo(); (&&&&&&&&f).foo(); }
三种 Deref 转换
在之前,我们讲的都是不可变的 Deref 转换,实际上 Rust 还支持将一个可变的引用转换成另一个可变的引用以及将一个可变引用转换成不可变的引用,规则如下:
- 当
T: Deref<Target=U>,可以将&T转换成&U,也就是我们之前看到的例子 - 当
T: DerefMut<Target=U>,可以将&mut T转换成&mut U - 当
T: Deref<Target=U>,可以将&mut T转换成&U
来看一个关于 DerefMut 的例子:
struct MyBox<T> { v: T, } impl<T> MyBox<T> { fn new(x: T) -> MyBox<T> { MyBox { v: x } } } use std::ops::Deref; impl<T> Deref for MyBox<T> { type Target = T; fn deref(&self) -> &Self::Target { &self.v } } use std::ops::DerefMut; impl<T> DerefMut for MyBox<T> { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.v } } fn main() { let mut s = MyBox::new(String::from("hello, ")); display(&mut s) } fn display(s: &mut String) { s.push_str("world"); println!("{}", s); }
以上代码有几点值得注意:
- 要实现
DerefMut必须要先实现Deref特征:pub trait DerefMut: Deref T: DerefMut<Target=U>解读:将&mut T类型通过DerefMut特征的方法转换为&mut U类型,对应上例中,就是将&mut MyBox<String>转换为&mut String
对于上述三条规则中的第三条,它比另外两条稍微复杂了点:Rust 可以把可变引用隐式的转换成不可变引用,但反之则不行。
如果从 Rust 的所有权和借用规则的角度考虑,当你拥有一个可变的引用,那该引用肯定是对应数据的唯一借用,那么此时将可变引用变成不可变引用并不会破坏借用规则;但是如果你拥有一个不可变引用,那同时可能还存在其它几个不可变的引用,如果此时将其中一个不可变引用转换成可变引用,就变成了可变引用与不可变引用的共存,最终破坏了借用规则。
总结
Deref 可以说是 Rust 中最常见的隐式类型转换,而且它可以连续的实现如 Box<String> -> String -> &str 的隐式转换,只要链条上的类型实现了 Deref 特征。
我们也可以为自己的类型实现 Deref 特征,但是原则上来说,只应该为自定义的智能指针实现 Deref。例如,虽然你可以为自己的自定义数组类型实现 Deref 以避免 myArr.0[0] 的使用形式,但是 Rust 官方并不推荐这么做,特别是在你开发三方库时。
Drop 释放资源
在 Rust 中,我们之所以可以一拳打跑 GC 的同时一脚踢翻手动资源回收,主要就归功于 Drop 特征,同时它也是智能指针的必备特征之一。
学习目标
如何自动和手动释放资源及执行指定的收尾工作
Rust 中的资源回收
在一些无 GC 语言中,程序员在一个变量无需再被使用时,需要手动释放它占用的内存资源,如果忘记了,那么就会发生内存泄漏,最终臭名昭著的 OOM 问题可能就会发生。
而在 Rust 中,你可以指定在一个变量超出作用域时,执行一段特定的代码,最终编译器将帮你自动插入这段收尾代码。这样,就无需在每一个使用该变量的地方,都写一段代码来进行收尾工作和资源释放。不禁让人感叹,Rust 的大腿真粗,香!
没错,指定这样一段收尾工作靠的就是咱这章的主角 - Drop 特征。
一个不那么简单的 Drop 例子
struct HasDrop1; struct HasDrop2; impl Drop for HasDrop1 { fn drop(&mut self) { println!("Dropping HasDrop1!"); } } impl Drop for HasDrop2 { fn drop(&mut self) { println!("Dropping HasDrop2!"); } } struct HasTwoDrops { one: HasDrop1, two: HasDrop2, } impl Drop for HasTwoDrops { fn drop(&mut self) { println!("Dropping HasTwoDrops!"); } } struct Foo; impl Drop for Foo { fn drop(&mut self) { println!("Dropping Foo!") } } fn main() { let _x = HasTwoDrops { two: HasDrop2, one: HasDrop1, }; let _foo = Foo; println!("Running!"); }
上面代码虽然长,但是目的其实很单纯,就是为了观察不同情况下变量级别的、结构体内部字段的 Drop,有几点值得注意:
Drop特征中的drop方法借用了目标的可变引用,而不是拿走了所有权,这里先设置一个悬念,后边会讲- 结构体中每个字段都有自己的
Drop
来看看输出:
Running!
Dropping Foo!
Dropping HasTwoDrops!
Dropping HasDrop1!
Dropping HasDrop2!
嗯,结果符合预期,每个资源都成功的执行了收尾工作,虽然 println! 这种收尾工作毫无意义 =,=
Drop 的顺序
观察以上输出,我们可以得出以下关于 Drop 顺序的结论
- 变量级别,按照逆序的方式,
_x在_foo之前创建,因此_x在_foo之后被drop - 结构体内部,按照顺序的方式,结构体
_x中的字段按照定义中的顺序依次drop
没有实现 Drop 的结构体
实际上,就算你不为 _x 结构体实现 Drop 特征,它内部的两个字段依然会调用 drop,移除以下代码,并观察输出:
#![allow(unused)] fn main() { impl Drop for HasTwoDrops { fn drop(&mut self) { println!("Dropping HasTwoDrops!"); } } }
原因在于,Rust 自动为几乎所有类型都实现了 Drop 特征,因此就算你不手动为结构体实现 Drop,它依然会调用默认实现的 drop 函数,同时再调用每个字段的 drop 方法,最终打印出:
Dropping HasDrop1!
Dropping HasDrop2!
手动回收
当使用智能指针来管理锁的时候,你可能希望提前释放这个锁,然后让其它代码能及时获得锁,此时就需要提前去手动 drop。
但是在之前我们提到一个悬念,Drop::drop 只是借用了目标值的可变引用,所以,就算你提前调用了 drop,后面的代码依然可以使用目标值,但是这就会访问一个并不存在的值,非常不安全,好在 Rust 会阻止你:
#[derive(Debug)] struct Foo; impl Drop for Foo { fn drop(&mut self) { println!("Dropping Foo!") } } fn main() { let foo = Foo; foo.drop(); println!("Running!:{:?}", foo); }
报错如下:
error[E0040]: explicit use of destructor method
--> src/main.rs:37:9
|
37 | foo.drop();
| ----^^^^--
| | |
| | explicit destructor calls not allowed
| help: consider using `drop` function: `drop(foo)`
如上所示,编译器直接阻止了我们调用 Drop 特征的 drop 方法,原因是对于 Rust 而言,不允许显式地调用析构函数(这是一个用来清理实例的通用编程概念)。好在在报错的同时,编译器还给出了一个提示:使用 drop 函数。
针对编译器提示的 drop 函数,我们可以大胆推测下:它能够拿走目标值的所有权。现在来看看这个猜测正确与否,以下是 std::mem::drop 函数的签名:
#![allow(unused)] fn main() { pub fn drop<T>(_x: T) }
如上所示,drop 函数确实拿走了目标值的所有权,来验证下:
fn main() { let foo = Foo; drop(foo); // 以下代码会报错:借用了所有权被转移的值 // println!("Running!:{:?}", foo); }
Bingo,完美拿走了所有权,而且这种实现保证了后续的使用必定会导致编译错误,因此非常安全!
细心的同学可能已经注意到,这里直接调用了 drop 函数,并没有引入任何模块信息,原因是该函数在std::prelude里。
Drop 使用场景
对于 Drop 而言,主要有两个功能:
- 回收内存资源
- 执行一些收尾工作
对于第二点,在之前我们已经详细介绍过,因此这里主要对第一点进行下简单说明。
在绝大多数情况下,我们都无需手动去 drop 以回收内存资源,因为 Rust 会自动帮我们完成这些工作,它甚至会对复杂类型的每个字段都单独的调用 drop 进行回收!但是确实有极少数情况,需要你自己来回收资源的,例如文件描述符、网络 socket 等,当这些值超出作用域不再使用时,就需要进行关闭以释放相关的资源,在这些情况下,就需要使用者自己来解决 Drop 的问题。
互斥的 Copy 和 Drop
我们无法为一个类型同时实现 Copy 和 Drop 特征。因为实现了 Copy 的特征会被编译器隐式的复制,因此非常难以预测析构函数执行的时间和频率。因此这些实现了 Copy 的类型无法拥有析构函数。
#![allow(unused)] fn main() { #[derive(Copy)] struct Foo; impl Drop for Foo { fn drop(&mut self) { println!("Dropping Foo!") } } }
以上代码报错如下:
error[E0184]: the trait `Copy` may not be implemented for this type; the type has a destructor
--> src/main.rs:24:10
|
24 | #[derive(Copy)]
| ^^^^ Copy not allowed on types with destructors
总结
Drop 可以用于许多方面,来使得资源清理及收尾工作变得方便和安全,甚至可以用其创建我们自己的内存分配器!通过 Drop 特征和 Rust 所有权系统,你无需担心之后的代码清理,Rust 会自动考虑这些问题。
我们也无需担心意外的清理掉仍在使用的值,这会造成编译器错误:所有权系统确保引用总是有效的,也会确保 drop 只会在值不再被使用时被调用一次。
Rc 与 Arc
Rust 所有权机制要求一个值只能有一个所有者,在大多数情况下,都没有问题,但是考虑以下情况:
- 在图数据结构中,多个边可能会拥有同一个节点,该节点直到没有边指向它时,才应该被释放清理
- 在多线程中,多个线程可能会持有同一个数据,但是你受限于 Rust 的安全机制,无法同时获取该数据的可变引用
以上场景不是很常见,但是一旦遇到,就非常棘手,为了解决此类问题,Rust 在所有权机制之外又引入了额外的措施来简化相应的实现:通过引用计数的方式,允许一个数据资源在同一时刻拥有多个所有者。
这种实现机制就是 Rc 和 Arc,前者适用于单线程,后者适用于多线程。由于二者大部分情况下都相同,因此本章将以 Rc 作为讲解主体,对于 Arc 的不同之处,另外进行单独讲解。
Rc<T>
引用计数(reference counting),顾名思义,通过记录一个数据被引用的次数来确定该数据是否正在被使用。当引用次数归零时,就代表该数据不再被使用,因此可以被清理释放。
而 Rc 正是引用计数的英文缩写。当我们希望在堆上分配一个对象供程序的多个部分使用且无法确定哪个部分最后一个结束时,就可以使用 Rc 成为数据值的所有者,例如之前提到的多线程场景就非常适合。
下面是经典的所有权被转移导致报错的例子:
fn main() { let s = String::from("hello, world"); // s在这里被转移给a let a = Box::new(s); // 报错!此处继续尝试将 s 转移给 b let b = Box::new(s); }
使用 Rc 就可以轻易解决:
use std::rc::Rc; fn main() { let a = Rc::new(String::from("hello, world")); let b = Rc::clone(&a); assert_eq!(2, Rc::strong_count(&a)); assert_eq!(Rc::strong_count(&a), Rc::strong_count(&b)) }
以上代码我们使用 Rc::new 创建了一个新的 Rc<String> 智能指针并赋给变量 a,该指针指向底层的字符串数据。
智能指针 Rc<T> 在创建时,还会将引用计数加 1,此时获取引用计数的关联函数 Rc::strong_count 返回的值将是 1。
Rc::clone
接着,我们又使用 Rc::clone 克隆了一份智能指针 Rc<String>,并将该智能指针的引用计数增加到 2。
由于 a 和 b 是同一个智能指针的两个副本,因此通过它们两个获取引用计数的结果都是 2。
不要被 clone 字样所迷惑,以为所有的 clone 都是深拷贝。这里的 clone 仅仅复制了智能指针并增加了引用计数,并没有克隆底层数据,因此 a 和 b 是共享了底层的字符串 s,这种复制效率是非常高的。当然你也可以使用 a.clone() 的方式来克隆,但是从可读性角度,我们更加推荐 Rc::clone 的方式。
实际上在 Rust 中,还有不少 clone 都是浅拷贝,例如迭代器的克隆。
观察引用计数的变化
使用关联函数 Rc::strong_count 可以获取当前引用计数的值,我们来观察下引用计数如何随着变量声明、释放而变化:
use std::rc::Rc; fn main() { let a = Rc::new(String::from("test ref counting")); println!("count after creating a = {}", Rc::strong_count(&a)); let b = Rc::clone(&a); println!("count after creating b = {}", Rc::strong_count(&a)); { let c = Rc::clone(&a); println!("count after creating c = {}", Rc::strong_count(&c)); } println!("count after c goes out of scope = {}", Rc::strong_count(&a)); }
有几点值得注意:
- 由于变量
c在语句块内部声明,当离开语句块时它会因为超出作用域而被释放,所以引用计数会减少 1,事实上这个得益于Rc<T>实现了Drop特征 a、b、c三个智能指针引用计数都是同样的,并且共享底层的数据,因此打印计数时用哪个都行- 无法看到的是:当
a、b超出作用域后,引用计数会变成 0,最终智能指针和它指向的底层字符串都会被清理释放
不可变引用
事实上,Rc<T> 是指向底层数据的不可变的引用,因此你无法通过它来修改数据,这也符合 Rust 的借用规则:要么存在多个不可变借用,要么只能存在一个可变借用。
但是实际开发中我们往往需要对数据进行修改,这时单独使用 Rc<T> 无法满足我们的需求,需要配合其它数据类型来一起使用,例如内部可变性的 RefCell<T> 类型以及互斥锁 Mutex<T>。事实上,在多线程编程中,Arc 跟 Mutex 锁的组合使用非常常见,它们既可以让我们在不同的线程中共享数据,又允许在各个线程中对其进行修改。
一个综合例子
考虑一个场景,有很多小工具,每个工具都有自己的主人,但是存在多个工具属于同一个主人的情况,此时使用 Rc<T> 就非常适合:
use std::rc::Rc; struct Owner { name: String, // ...其它字段 } struct Gadget { id: i32, owner: Rc<Owner>, // ...其它字段 } fn main() { // 创建一个基于引用计数的 `Owner`. let gadget_owner: Rc<Owner> = Rc::new(Owner { name: "Gadget Man".to_string(), }); // 创建两个不同的工具,它们属于同一个主人 let gadget1 = Gadget { id: 1, owner: Rc::clone(&gadget_owner), }; let gadget2 = Gadget { id: 2, owner: Rc::clone(&gadget_owner), }; // 释放掉第一个 `Rc<Owner>` drop(gadget_owner); // 尽管在上面我们释放了 gadget_owner,但是依然可以在这里使用 owner 的信息 // 原因是在 drop 之前,存在三个指向 Gadget Man 的智能指针引用,上面仅仅 // drop 掉其中一个智能指针引用,而不是 drop 掉 owner 数据,外面还有两个 // 引用指向底层的 owner 数据,引用计数尚未清零 // 因此 owner 数据依然可以被使用 println!("Gadget {} owned by {}", gadget1.id, gadget1.owner.name); println!("Gadget {} owned by {}", gadget2.id, gadget2.owner.name); // 在函数最后,`gadget1` 和 `gadget2` 也被释放,最终引用计数归零,随后底层 // 数据也被清理释放 }
以上代码很好的展示了 Rc<T> 的用途,当然你也可以用借用的方式,但是实现起来就会复杂得多,而且随着 Gadget 在代码的各个地方使用,引用生命周期也将变得更加复杂,毕竟结构体中的引用类型,总是令人不那么愉快,对不?
Rc 简单总结
Rc/Arc是不可变引用,你无法修改它指向的值,只能进行读取,如果要修改,需要配合后面章节的内部可变性RefCell或互斥锁Mutex- 一旦最后一个拥有者消失,则资源会自动被回收,这个生命周期是在编译期就确定下来的
Rc只能用于同一线程内部,想要用于线程之间的对象共享,你需要使用ArcRc<T>是一个智能指针,实现了Deref特征,因此你无需先解开Rc指针,再使用里面的T,而是可以直接使用T,例如上例中的gadget1.owner.name
多线程无力的 Rc<T>
来看看在多线程场景使用 Rc<T> 会如何:
use std::rc::Rc; use std::thread; fn main() { let s = Rc::new(String::from("多线程漫游者")); for _ in 0..10 { let s = Rc::clone(&s); let handle = thread::spawn(move || { println!("{}", s) }); } }
由于我们还没有学习多线程的章节,上面的例子就特地简化了相关的实现。首先通过 thread::spawn 创建一个线程,然后使用 move 关键字把克隆出的 s 的所有权转移到线程中。
能够实现这一点,完全得益于 Rc 带来的多所有权机制,但是以上代码会报错:
error[E0277]: `Rc<String>` cannot be sent between threads safely
表面原因是 Rc<T> 不能在线程间安全的传递,实际上是因为它没有实现 Send 特征,而该特征是恰恰是多线程间传递数据的关键,我们会在多线程章节中进行讲解。
当然,还有更深层的原因:由于 Rc<T> 需要管理引用计数,但是该计数器并没有使用任何并发原语,因此无法实现原子化的计数操作,最终会导致计数错误。
好在天无绝人之路,一起来看看 Rust 为我们提供的功能类似但是多线程安全的 Arc。
Arc
Arc 是 Atomic Rc 的缩写,顾名思义:原子化的 Rc<T> 智能指针。原子化是一种并发原语,我们在后续章节会进行深入讲解,这里你只要知道它能保证我们的数据能够安全的在线程间共享即可。
Arc 的性能损耗
你可能好奇,为何不直接使用 Arc,还要画蛇添足弄一个 Rc,还有 Rust 的基本数据类型、标准库数据类型为什么不自动实现原子化操作?这样就不存在线程不安全的问题了。
原因在于原子化或者其它锁虽然可以带来的线程安全,但是都会伴随着性能损耗,而且这种性能损耗还不小。因此 Rust 把这种选择权交给你,毕竟需要线程安全的代码其实占比并不高,大部分时候我们开发的程序都在一个线程内。
Arc 和 Rc 拥有完全一样的 API,修改起来很简单:
use std::sync::Arc; use std::thread; fn main() { let s = Arc::new(String::from("多线程漫游者")); for _ in 0..10 { let s = Arc::clone(&s); let handle = thread::spawn(move || { println!("{}", s) }); } }
对了,两者还有一点区别:Arc 和 Rc 并没有定义在同一个模块,前者通过 use std::sync::Arc 来引入,后者通过 use std::rc::Rc。
总结
在 Rust 中,所有权机制保证了一个数据只会有一个所有者,但如果你想要在图数据结构、多线程等场景中共享数据,这种机制会成为极大的阻碍。好在 Rust 为我们提供了智能指针 Rc 和 Arc,使用它们就能实现多个所有者共享一个数据的功能。
Rc 和 Arc 的区别在于,后者是原子化实现的引用计数,因此是线程安全的,可以用于多线程中共享数据。
这两者都是只读的,如果想要实现内部数据可修改,必须配合内部可变性 RefCell 或者互斥锁 Mutex 来一起使用。
Cell 和 RefCell
Rust 的编译器之严格,可以说是举世无双。特别是在所有权方面,Rust 通过严格的规则来保证所有权和借用的正确性,最终为程序的安全保驾护航。
但是严格是一把双刃剑,带来安全提升的同时,损失了灵活性,有时甚至会让用户痛苦不堪、怨声载道。因此 Rust 提供了 Cell 和 RefCell 用于内部可变性,简而言之,可以在拥有不可变引用的同时修改目标数据,对于正常的代码实现来说,这个是不可能做到的(要么一个可变借用,要么多个不可变借用)。
内部可变性的实现是因为 Rust 使用了
unsafe来做到这一点,但是对于使用者来说,这些都是透明的,因为这些不安全代码都被封装到了安全的 API 中
Cell
Cell 和 RefCell 在功能上没有区别,区别在于 Cell<T> 适用于 T 实现 Copy 的情况:
use std::cell::Cell; fn main() { let c = Cell::new("asdf"); let one = c.get(); c.set("qwer"); let two = c.get(); println!("{},{}", one, two); }
以上代码展示了 Cell 的基本用法,有几点值得注意:
- "asdf" 是
&str类型,它实现了Copy特征 c.get用来取值,c.set用来设置新值
取到值保存在 one 变量后,还能同时进行修改,这个违背了 Rust 的借用规则,但是由于 Cell 的存在,我们很优雅地做到了这一点,但是如果你尝试在 Cell 中存放String:
#![allow(unused)] fn main() { let c = Cell::new(String::from("asdf")); }
编译器会立刻报错,因为 String 没有实现 Copy 特征:
| pub struct String {
| ----------------- doesn't satisfy `String: Copy`
|
= note: the following trait bounds were not satisfied:
`String: Copy`
RefCell
由于 Cell 类型针对的是实现了 Copy 特征的值类型,因此在实际开发中,Cell 使用的并不多,因为我们要解决的往往是可变、不可变引用共存导致的问题,此时就需要借助于 RefCell 来达成目的。
我们可以将所有权、借用规则与这些智能指针做一个对比:
| Rust 规则 | 智能指针带来的额外规则 |
|---|---|
| 一个数据只有一个所有者 | Rc/Arc让一个数据可以拥有多个所有者 |
| 要么多个不可变借用,要么一个可变借用 | RefCell实现编译期可变、不可变引用共存 |
| 违背规则导致编译错误 | 违背规则导致运行时panic |
可以看出,Rc/Arc 和 RefCell 合在一起,解决了 Rust 中严苛的所有权和借用规则带来的某些场景下难使用的问题。但是它们并不是银弹,例如 RefCell 实际上并没有解决可变引用和引用可以共存的问题,只是将报错从编译期推迟到运行时,从编译器错误变成了 panic 异常:
use std::cell::RefCell; fn main() { let s = RefCell::new(String::from("hello, world")); let s1 = s.borrow(); let s2 = s.borrow_mut(); println!("{},{}", s1, s2); }
上面代码在编译期不会报任何错误,你可以顺利运行程序:
thread 'main' panicked at 'already borrowed: BorrowMutError', src/main.rs:6:16
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
但是依然会因为违背了借用规则导致了运行期 panic,这非常像中国的天网,它也许会被罪犯蒙蔽一时,但是并不会被蒙蔽一世,任何导致安全风险的存在都将不能被容忍,法网恢恢,疏而不漏。
RefCell 为何存在
相信肯定有读者有疑问了,这么做有任何意义吗?还不如在编译期报错,至少能提前发现问题,而且性能还更好。
存在即合理,究其根因,在于 Rust 编译期的宁可错杀,绝不放过的原则,当编译器不能确定你的代码是否正确时,就统统会判定为错误,因此难免会导致一些误报。
而 RefCell 正是用于你确信代码是正确的,而编译器却发生了误判时。
对于大型的复杂程序,也可以选择使用 RefCell 来让事情简化。例如在 Rust 编译器的ctxt结构体中有大量的 RefCell 类型的 map 字段,主要的原因是:这些 map 会被分散在各个地方的代码片段所广泛使用或修改。由于这种分散在各处的使用方式,导致了管理可变和不可变成为一件非常复杂的任务(甚至不可能),你很容易就碰到编译器抛出来的各种错误。而且 RefCell 的运行时错误在这种情况下也变得非常可爱:一旦有人做了不正确的使用,代码会 panic,然后告诉我们哪些借用冲突了。
总之,当你确信编译器误报但不知道该如何解决时,或者你有一个引用类型,需要被四处使用和修改然后导致借用关系难以管理时,都可以优先考虑使用 RefCell。
RefCell 简单总结
- 与
Cell用于可Copy的值不同,RefCell用于引用 RefCell只是将借用规则从编译期推迟到程序运行期,并不能帮你绕过这个规则RefCell适用于编译期误报或者一个引用被在多处代码使用、修改以至于难于管理借用关系时- 使用
RefCell时,违背借用规则会导致运行期的panic
选择 Cell 还是 RefCell
根据本文的内容,我们可以大概总结下两者的区别:
Cell只适用于Copy类型,用于提供值,而RefCell用于提供引用Cell不会panic,而RefCell会
性能比较
Cell 没有额外的性能损耗,例如以下两段代码的性能其实是一致的:
#![allow(unused)] fn main() { // code snipet 1 let x = Cell::new(1); let y = &x; let z = &x; x.set(2); y.set(3); z.set(4); println!("{}", x.get()); // code snipet 2 let mut x = 1; let y = &mut x; let z = &mut x; x = 2; *y = 3; *z = 4; println!("{}", x); }
虽然性能一致,但代码 1 拥有代码 2 不具有的优势:它能编译成功:)
与 Cell 的 zero cost 不同,RefCell 其实是有一点运行期开销的,原因是它包含了一个字大小的“借用状态”指示器,该指示器在每次运行时借用时都会被修改,进而产生一点开销。
总之,当非要使用内部可变性时,首选 Cell,只有你的类型没有实现 Copy 时,才去选择 RefCell。
内部可变性
之前我们提到 RefCell 具有内部可变性,何为内部可变性?简单来说,对一个不可变的值进行可变借用,但这个并不符合 Rust 的基本借用规则:
fn main() { let x = 5; let y = &mut x; }
上面的代码会报错,因为我们不能对一个不可变的值进行可变借用,这会破坏 Rust 的安全性保证,相反,你可以对一个可变值进行不可变借用。原因是:当值不可变时,可能会有多个不可变的引用指向它,此时若将修改其中一个为可变的,会造成可变引用与不可变引用共存的情况;而当值可变时,最多只会有一个可变引用指向它,将其修改为不可变,那么最终依然是只有一个不可变的引用指向它。
虽然基本借用规则是 Rust 的基石,然而在某些场景中,一个值可以在其方法内部被修改,同时对于其它代码不可变,是很有用的:
#![allow(unused)] fn main() { // 定义在外部库中的特征 pub trait Messenger { fn send(&self, msg: String); } // -------------------------- // 我们的代码中的数据结构和实现 struct MsgQueue { msg_cache: Vec<String>, } impl Messenger for MsgQueue { fn send(&self, msg: String) { self.msg_cache.push(msg) } } }
如上所示,外部库中定义了一个消息发送器特征 Messenger,它只有一个发送消息的功能:fn send(&self, msg: String),因为发送消息不需要修改自身,因此原作者在定义时,使用了 &self 的不可变借用,这个无可厚非。
我们要在自己的代码中使用该特征实现一个异步消息队列,出于性能的考虑,消息先写到本地缓存(内存)中,然后批量发送出去,因此在 send 方法中,需要将消息先行插入到本地缓存 msg_cache 中。但是问题来了,该 send 方法的签名是 &self,因此上述代码会报错:
error[E0596]: cannot borrow `self.msg_cache` as mutable, as it is behind a `&` reference
--> src/main.rs:11:9
|
2 | fn send(&self, msg: String);
| ----- help: consider changing that to be a mutable reference: `&mut self`
...
11 | self.msg_cache.push(msg)
| ^^^^^^^^^^^^^^^^^^ `self` is a `&` reference, so the data it refers to cannot be borrowed as mutable
在报错的同时,编译器大聪明还善意地给出了提示:将 &self 修改为 &mut self,但是。。。我们实现的特征是定义在外部库中,因此该签名根本不能修改。值此危急关头, RefCell 闪亮登场:
use std::cell::RefCell; pub trait Messenger { fn send(&self, msg: String); } pub struct MsgQueue { msg_cache: RefCell<Vec<String>>, } impl Messenger for MsgQueue { fn send(&self, msg: String) { self.msg_cache.borrow_mut().push(msg) } } fn main() { let mq = MsgQueue { msg_cache: RefCell::new(Vec::new()), }; mq.send("hello, world".to_string()); }
这个 MQ 功能很弱,但是并不妨碍我们演示内部可变性的核心用法:通过包裹一层 RefCell,成功的让 &self 中的 msg_cache 成为一个可变值,然后实现对其的修改。
Rc + RefCell 组合使用
在 Rust 中,一个常见的组合就是 Rc 和 RefCell 在一起使用,前者可以实现一个数据拥有多个所有者,后者可以实现数据的可变性:
use std::cell::RefCell; use std::rc::Rc; fn main() { let s = Rc::new(RefCell::new("我很善变,还拥有多个主人".to_string())); let s1 = s.clone(); let s2 = s.clone(); // let mut s2 = s.borrow_mut(); s2.borrow_mut().push_str(", on yeah!"); println!("{:?}\n{:?}\n{:?}", s, s1, s2); }
上面代码中,我们使用 RefCell<String> 包裹一个字符串,同时通过 Rc 创建了它的三个所有者:s、s1和s2,并且通过其中一个所有者 s2 对字符串内容进行了修改。
由于 Rc 的所有者们共享同一个底层的数据,因此当一个所有者修改了数据时,会导致全部所有者持有的数据都发生了变化。
程序的运行结果也在预料之中:
RefCell { value: "我很善变,还拥有多个主人, on yeah!" }
RefCell { value: "我很善变,还拥有多个主人, on yeah!" }
RefCell { value: "我很善变,还拥有多个主人, on yeah!" }
性能损耗
相信这两者组合在一起使用时,很多人会好奇到底性能如何,下面我们来简单分析下。
首先给出一个大概的结论,这两者结合在一起使用的性能其实非常高,大致相当于没有线程安全版本的 C++ std::shared_ptr 指针,事实上,C++ 这个指针的主要开销也在于原子性这个并发原语上,毕竟线程安全在哪个语言中开销都不小。
内存损耗
两者结合的数据结构与下面类似:
#![allow(unused)] fn main() { struct Wrapper<T> { // Rc strong_count: usize, weak_count: usize, // Refcell borrow_count: isize, // 包裹的数据 item: T, } }
从上面可以看出,从对内存的影响来看,仅仅多分配了三个usize/isize,并没有其它额外的负担。
CPU 损耗
从 CPU 来看,损耗如下:
- 对
Rc<T>解引用是免费的(编译期),但是*带来的间接取值并不免费 - 克隆
Rc<T>需要将当前的引用计数跟0和usize::Max进行一次比较,然后将计数值加 1 - 释放(drop)
Rc<T>需要将计数值减 1, 然后跟0进行一次比较 - 对
RefCell进行不可变借用,需要将isize类型的借用计数加 1,然后跟0进行比较 - 对
RefCell的不可变借用进行释放,需要将isize减 1 - 对
RefCell的可变借用大致流程跟上面差不多,但是需要先跟0比较,然后再减 1 - 对
RefCell的可变借用进行释放,需要将isize加 1
其实这些细节不必过于关注,只要知道 CPU 消耗也非常低,甚至编译器还会对此进行进一步优化!
CPU 缓存 Miss
唯一需要担心的可能就是这种组合数据结构对于 CPU 缓存是否亲和,这个我们无法证明,只能提出来存在这个可能性,最终的性能影响还需要在实际场景中进行测试。
总之,分析这两者组合的性能还挺复杂的,大概总结下:
- 从表面来看,它们带来的内存和 CPU 损耗都不大
- 但是由于
Rc额外的引入了一次间接取值(*),在少数场景下可能会造成性能上的显著损失 - CPU 缓存可能也不够亲和
通过 Cell::from_mut 解决借用冲突
在 Rust 1.37 版本中新增了两个非常实用的方法:
- Cell::from_mut,该方法将
&mut T转为&Cell<T> - Cell::as_slice_of_cells,该方法将
&Cell<[T]>转为&[Cell<T>]
这里我们不做深入的介绍,但是来看看如何使用这两个方法来解决一个常见的借用冲突问题:
#![allow(unused)] fn main() { fn is_even(i: i32) -> bool { i % 2 == 0 } fn retain_even(nums: &mut Vec<i32>) { let mut i = 0; for num in nums.iter().filter(|&num| is_even(*num)) { nums[i] = *num; i += 1; } nums.truncate(i); } }
以上代码会报错:
error[E0502]: cannot borrow `*nums` as mutable because it is also borrowed as immutable
--> src/main.rs:8:9
|
7 | for num in nums.iter().filter(|&num| is_even(*num)) {
| ----------------------------------------
| |
| immutable borrow occurs here
| immutable borrow later used here
8 | nums[i] = *num;
| ^^^^ mutable borrow occurs here
很明显,报错是因为同时借用了不可变与可变引用,你可以通过索引的方式来避免这个问题:
#![allow(unused)] fn main() { fn retain_even(nums: &mut Vec<i32>) { let mut i = 0; for j in 0..nums.len() { if is_even(nums[j]) { nums[i] = nums[j]; i += 1; } } nums.truncate(i); } }
但是这样就违背我们的初衷了,毕竟迭代器会让代码更加简洁,那么还有其它的办法吗?
这时就可以使用 Cell 新增的这两个方法:
#![allow(unused)] fn main() { use std::cell::Cell; fn retain_even(nums: &mut Vec<i32>) { let slice: &[Cell<i32>] = Cell::from_mut(&mut nums[..]) .as_slice_of_cells(); let mut i = 0; for num in slice.iter().filter(|num| is_even(num.get())) { slice[i].set(num.get()); i += 1; } nums.truncate(i); } }
此时代码将不会报错,因为 Cell 上的 set 方法获取的是不可变引用 pub fn set(&self, val: T)。
当然,以上代码的本质还是对 Cell 的运用,只不过这两个方法可以很方便的帮我们把 &mut [T] 类型转换成 &[Cell<T>] 类型。
总结
Cell 和 RefCell 都为我们带来了内部可变性这个重要特性,同时还将借用规则的检查从编译期推迟到运行期,但是这个检查并不能被绕过,该来早晚还是会来,RefCell 在运行期的报错会造成 panic。
RefCell 适用于编译器误报或者一个引用被在多个代码中使用、修改以至于难于管理借用关系时,还有就是需要内部可变性时。
从性能上看,RefCell 由于是非线程安全的,因此无需保证原子性,性能虽然有一点损耗,但是依然非常好,而 Cell 则完全不存在任何额外的性能损耗。
Rc 跟 RefCell 结合使用可以实现多个所有者共享同一份数据,非常好用,但是潜在的性能损耗也要考虑进去,建议对于热点代码使用时,做好 benchmark。
循环引用与自引用
实现一个链表是学习各大编程语言的常用技巧,但是在 Rust 中实现链表意味着····Hell,是的,你没看错,Welcome to hell。
链表在 Rust 中之所以这么难,完全是因为循环引用和自引用的问题引起的,这两个问题可以说综合了 Rust 的很多难点,难出了新高度,因此本书专门开辟一章,分为上下两篇,试图彻底解决这两个老大难。
本章难度较高,但是非常值得深入阅读,它会让你对 Rust 的理解上升到一个新的境界。
Weak 与循环引用
Rust 的安全性是众所周知的,但是不代表它不会内存泄漏。一个典型的例子就是同时使用 Rc<T> 和 RefCell<T> 创建循环引用,最终这些引用的计数都无法被归零,因此 Rc<T> 拥有的值也不会被释放清理。
何为循环引用
关于内存泄漏,如果你没有充足的 Rust 经验,可能都无法造出一份代码来再现它:
use crate::List::{Cons, Nil}; use std::cell::RefCell; use std::rc::Rc; #[derive(Debug)] enum List { Cons(i32, RefCell<Rc<List>>), Nil, } impl List { fn tail(&self) -> Option<&RefCell<Rc<List>>> { match self { Cons(_, item) => Some(item), Nil => None, } } } fn main() {}
这里我们创建一个有些复杂的枚举类型 List,这个类型很有意思,它的每个值都指向了另一个 List,此外,得益于 Rc 的使用还允许多个值指向一个 List:
如上图所示,每个矩形框节点都是一个 List 类型,它们或者是拥有值且指向另一个 List 的Cons,或者是一个没有值的终结点 Nil。同时,由于 RefCell 的使用,每个 List 所指向的 List 还能够被修改。
下面来使用一下这个复杂的 List 枚举:
fn main() { let a = Rc::new(Cons(5, RefCell::new(Rc::new(Nil)))); println!("a的初始化rc计数 = {}", Rc::strong_count(&a)); println!("a指向的节点 = {:?}", a.tail()); // 创建`b`到`a`的引用 let b = Rc::new(Cons(10, RefCell::new(Rc::clone(&a)))); println!("在b创建后,a的rc计数 = {}", Rc::strong_count(&a)); println!("b的初始化rc计数 = {}", Rc::strong_count(&b)); println!("b指向的节点 = {:?}", b.tail()); // 利用RefCell的可变性,创建了`a`到`b`的引用 if let Some(link) = a.tail() { *link.borrow_mut() = Rc::clone(&b); } println!("在更改a后,b的rc计数 = {}", Rc::strong_count(&b)); println!("在更改a后,a的rc计数 = {}", Rc::strong_count(&a)); // 下面一行println!将导致循环引用 // 我们可怜的8MB大小的main线程栈空间将被它冲垮,最终造成栈溢出 // println!("a next item = {:?}", a.tail()); }
这个类型定义看着复杂,使用起来更复杂!不过排除这些因素,我们可以清晰看出:
- 在创建了
a后,紧接着就使用a创建了b,因此b引用了a - 然后我们又利用
Rc克隆了b,然后通过RefCell的可变性,让a引用了b
至此我们成功创建了循环引用a-> b -> a -> b ····
先来观察下引用计数:
a的初始化rc计数 = 1
a指向的节点 = Some(RefCell { value: Nil })
在b创建后,a的rc计数 = 2
b的初始化rc计数 = 1
b指向的节点 = Some(RefCell { value: Cons(5, RefCell { value: Nil }) })
在更改a后,b的rc计数 = 2
在更改a后,a的rc计数 = 2
在 main 函数结束前,a 和 b 的引用计数均是 2,随后 b 触发 Drop,此时引用计数会变为 1,并不会归 0,因此 b 所指向内存不会被释放,同理可得 a 指向的内存也不会被释放,最终发生了内存泄漏。
下面一张图很好的展示了这种引用循环关系:

现在我们还需要轻轻的推一下,让塔米诺骨牌轰然倒塌。反注释最后一行代码,试着运行下:
RefCell { value: Cons(5, RefCell { value: Cons(10, RefCell { value: Cons(5, RefCell { value: Cons(10, RefCell { value: Cons(5, RefCell { value: Cons(10, RefCell {
...无穷无尽
thread 'main' has overflowed its stack
fatal runtime error: stack overflow
通过 a.tail 的调用,Rust 试图打印出 a -> b -> a ··· 的所有内容,但是在不懈的努力后,main 线程终于不堪重负,发生了栈溢出。
以上的代码可能并不会造成什么大的问题,但是在一个更加复杂的程序中,类似的问题可能会造成你的程序不断地分配内存、泄漏内存,最终程序会不幸OOM,当然这其中的 CPU 损耗也不可小觑。
总之,创建循环引用并不简单,但是也并不是完全遇不到,当你使用 RefCell<Rc<T>> 或者类似的类型嵌套组合(具备内部可变性和引用计数)时,就要打起万分精神,前面可能是深渊!
那么问题来了? 如果我们确实需要实现上面的功能,该怎么办?答案是使用 Weak。
Weak
Weak 非常类似于 Rc,但是与 Rc 持有所有权不同,Weak 不持有所有权,它仅仅保存一份指向数据的弱引用:如果你想要访问数据,需要通过 Weak 指针的 upgrade 方法实现,该方法返回一个类型为 Option<Rc<T>> 的值。
看到这个返回,相信大家就懂了:何为弱引用?就是不保证引用关系依然存在,如果不存在,就返回一个 None!
因为 Weak 引用不计入所有权,因此它无法阻止所引用的内存值被释放掉,而且 Weak 本身不对值的存在性做任何担保,引用的值还存在就返回 Some,不存在就返回 None。
Weak 与 Rc 对比
我们来将 Weak 与 Rc 进行以下简单对比:
Weak | Rc |
|---|---|
| 不计数 | 引用计数 |
| 不拥有所有权 | 拥有值的所有权 |
| 不阻止值被释放(drop) | 所有权计数归零,才能 drop |
引用的值存在返回 Some,不存在返回 None | 引用的值必定存在 |
通过 upgrade 取到 Option<Rc<T>>,然后再取值 | 通过 Deref 自动解引用,取值无需任何操作 |
通过这个对比,可以非常清晰的看出 Weak 为何这么弱,而这种弱恰恰非常适合我们实现以下的场景:
- 持有一个
Rc对象的临时引用,并且不在乎引用的值是否依然存在 - 阻止
Rc导致的循环引用,因为Rc的所有权机制,会导致多个Rc都无法计数归零
使用方式简单总结下:对于父子引用关系,可以让父节点通过 Rc 来引用子节点,然后让子节点通过 Weak 来引用父节点。
Weak 总结
因为 Weak 本身并不是很好理解,因此我们再来帮大家梳理总结下,然后再通过一个例子,来彻底掌握。
Weak 通过 use std::rc::Weak 来引入,它具有以下特点:
- 可访问,但没有所有权,不增加引用计数,因此不会影响被引用值的释放回收
- 可由
Rc<T>调用downgrade方法转换成Weak<T> Weak<T>可使用upgrade方法转换成Option<Rc<T>>,如果资源已经被释放,则Option的值是None- 常用于解决循环引用的问题
一个简单的例子:
use std::rc::Rc; fn main() { // 创建Rc,持有一个值5 let five = Rc::new(5); // 通过Rc,创建一个Weak指针 let weak_five = Rc::downgrade(&five); // Weak引用的资源依然存在,取到值5 let strong_five: Option<Rc<_>> = weak_five.upgrade(); assert_eq!(*strong_five.unwrap(), 5); // 手动释放资源`five` drop(five); // Weak引用的资源已不存在,因此返回None let strong_five: Option<Rc<_>> = weak_five.upgrade(); assert_eq!(strong_five, None); }
需要承认的是,使用 Weak 让 Rust 本来就堪忧的代码可读性又下降了不少,但是。。。真香,因为可以解决循环引用了。
使用 Weak 解决循环引用
理论知识已经足够,现在用两个例子来模拟下真实场景下可能会遇到的循环引用。
工具间的故事
工具间里,每个工具都有其主人,且多个工具可以拥有一个主人;同时一个主人也可以拥有多个工具,在这种场景下,就很容易形成循环引用,好在我们有 Weak:
use std::rc::Rc; use std::rc::Weak; use std::cell::RefCell; // 主人 struct Owner { name: String, gadgets: RefCell<Vec<Weak<Gadget>>>, } // 工具 struct Gadget { id: i32, owner: Rc<Owner>, } fn main() { // 创建一个 Owner // 需要注意,该 Owner 也拥有多个 `gadgets` let gadget_owner : Rc<Owner> = Rc::new( Owner { name: "Gadget Man".to_string(), gadgets: RefCell::new(Vec::new()), } ); // 创建工具,同时与主人进行关联:创建两个 gadget,他们分别持有 gadget_owner 的一个引用。 let gadget1 = Rc::new(Gadget{id: 1, owner: gadget_owner.clone()}); let gadget2 = Rc::new(Gadget{id: 2, owner: gadget_owner.clone()}); // 为主人更新它所拥有的工具 // 因为之前使用了 `Rc`,现在必须要使用 `Weak`,否则就会循环引用 gadget_owner.gadgets.borrow_mut().push(Rc::downgrade(&gadget1)); gadget_owner.gadgets.borrow_mut().push(Rc::downgrade(&gadget2)); // 遍历 gadget_owner 的 gadgets 字段 for gadget_opt in gadget_owner.gadgets.borrow().iter() { // gadget_opt 是一个 Weak<Gadget> 。 因为 weak 指针不能保证他所引用的对象 // 仍然存在。所以我们需要显式的调用 upgrade() 来通过其返回值(Option<_>)来判 // 断其所指向的对象是否存在。 // 当然,Option 为 None 的时候这个引用原对象就不存在了。 let gadget = gadget_opt.upgrade().unwrap(); println!("Gadget {} owned by {}", gadget.id, gadget.owner.name); } // 在 main 函数的最后,gadget_owner,gadget1 和 gadget2 都被销毁。 // 具体是,因为这几个结构体之间没有了强引用(`Rc<T>`),所以,当他们销毁的时候。 // 首先 gadget2 和 gadget1 被销毁。 // 然后因为 gadget_owner 的引用数量为 0,所以这个对象可以被销毁了。 // 循环引用问题也就避免了 }
tree 数据结构
use std::cell::RefCell; use std::rc::{Rc, Weak}; #[derive(Debug)] struct Node { value: i32, parent: RefCell<Weak<Node>>, children: RefCell<Vec<Rc<Node>>>, } fn main() { let leaf = Rc::new(Node { value: 3, parent: RefCell::new(Weak::new()), children: RefCell::new(vec![]), }); println!( "leaf strong = {}, weak = {}", Rc::strong_count(&leaf), Rc::weak_count(&leaf), ); { let branch = Rc::new(Node { value: 5, parent: RefCell::new(Weak::new()), children: RefCell::new(vec![Rc::clone(&leaf)]), }); *leaf.parent.borrow_mut() = Rc::downgrade(&branch); println!( "branch strong = {}, weak = {}", Rc::strong_count(&branch), Rc::weak_count(&branch), ); println!( "leaf strong = {}, weak = {}", Rc::strong_count(&leaf), Rc::weak_count(&leaf), ); } println!("leaf parent = {:?}", leaf.parent.borrow().upgrade()); println!( "leaf strong = {}, weak = {}", Rc::strong_count(&leaf), Rc::weak_count(&leaf), ); }
这个例子就留给读者自己解读和分析,我们就不画蛇添足了:)
unsafe 解决循环引用
除了使用 Rust 标准库提供的这些类型,你还可以使用 unsafe 里的裸指针来解决这些棘手的问题,但是由于我们还没有讲解 unsafe,因此这里就不进行展开,只附上源码链接, 挺长的,需要耐心 o_o
虽然 unsafe 不安全,但是在各种库的代码中依然很常见用它来实现自引用结构,主要优点如下:
- 性能高,毕竟直接用裸指针操作
- 代码更简单更符合直觉: 对比下
Option<Rc<RefCell<Node>>>
总结
本文深入讲解了何为循环引用以及如何使用 Weak 来解决,同时还结合 Rc、RefCell、Weak 等实现了两个有实战价值的例子,让大家对智能指针的使用更加融会贯通。
至此,智能指针一章即将结束(严格来说还有一个 Mutex 放在多线程一章讲解),而 Rust 语言本身的学习之旅也即将结束,后面我们将深入多线程、项目工程、应用实践、性能分析等特色专题,来一睹 Rust 在这些领域的风采。
结构体自引用
结构体自引用在 Rust 中是一个众所周知的难题,而且众说纷纭,也没有一篇文章能把相关的话题讲透,那本文就王婆卖瓜,来试试看能不能讲透这一块儿内容,让读者大大们舒心。
平平无奇的自引用
可能也有不少人第一次听说自引用结构体,那咱们先来看看它们长啥样。
#![allow(unused)] fn main() { struct SelfRef<'a> { value: String, // 该引用指向上面的value pointer_to_value: &'a str, } }
以上就是一个很简单的自引用结构体,看上去好像没什么,那来试着运行下:
fn main(){ let s = "aaa".to_string(); let v = SelfRef { value: s, pointer_to_value: &s }; }
运行后报错:
let v = SelfRef {
12 | value: s,
| - value moved here
13 | pointer_to_value: &s
| ^^ value borrowed here after move
因为我们试图同时使用值和值的引用,最终所有权转移和借用一起发生了。所以,这个问题貌似并没有那么好解决,不信你可以回想下自己具有的知识,是否可以解决?
使用 Option
最简单的方式就是使用 Option 分两步来实现:
#[derive(Debug)] struct WhatAboutThis<'a> { name: String, nickname: Option<&'a str>, } fn main() { let mut tricky = WhatAboutThis { name: "Annabelle".to_string(), nickname: None, }; tricky.nickname = Some(&tricky.name[..4]); println!("{:?}", tricky); }
在某种程度上来说,Option 这个方法可以工作,但是这个方法的限制较多,例如从一个函数创建并返回它是不可能的:
#![allow(unused)] fn main() { fn creator<'a>() -> WhatAboutThis<'a> { let mut tricky = WhatAboutThis { name: "Annabelle".to_string(), nickname: None, }; tricky.nickname = Some(&tricky.name[..4]); tricky } }
报错如下:
error[E0515]: cannot return value referencing local data `tricky.name`
--> src/main.rs:24:5
|
22 | tricky.nickname = Some(&tricky.name[..4]);
| ----------- `tricky.name` is borrowed here
23 |
24 | tricky
| ^^^^^^ returns a value referencing data owned by the current function
其实从函数签名就能看出来端倪,'a 生命周期是凭空产生的!
如果是通过方法使用,你需要一个无用 &'a self 生命周期标识,一旦有了这个标识,代码将变得更加受限,你将很容易就获得借用错误,就连 NLL 规则都没用:
#[derive(Debug)] struct WhatAboutThis<'a> { name: String, nickname: Option<&'a str>, } impl<'a> WhatAboutThis<'a> { fn tie_the_knot(&'a mut self) { self.nickname = Some(&self.name[..4]); } } fn main() { let mut tricky = WhatAboutThis { name: "Annabelle".to_string(), nickname: None, }; tricky.tie_the_knot(); // cannot borrow `tricky` as immutable because it is also borrowed as mutable // println!("{:?}", tricky); }
unsafe 实现
既然借用规则妨碍了我们,那就一脚踢开:
#[derive(Debug)] struct SelfRef { value: String, pointer_to_value: *const String, } impl SelfRef { fn new(txt: &str) -> Self { SelfRef { value: String::from(txt), pointer_to_value: std::ptr::null(), } } fn init(&mut self) { let self_ref: *const String = &self.value; self.pointer_to_value = self_ref; } fn value(&self) -> &str { &self.value } fn pointer_to_value(&self) -> &String { assert!(!self.pointer_to_value.is_null(), "Test::b called without Test::init being called first"); unsafe { &*(self.pointer_to_value) } } } fn main() { let mut t = SelfRef::new("hello"); t.init(); // 打印值和指针地址 println!("{}, {:p}", t.value(), t.pointer_to_value()); }
在这里,我们在 pointer_to_value 中直接存储裸指针,而不是 Rust 的引用,因此不再受到 Rust 借用规则和生命周期的限制,而且实现起来非常清晰、简洁。但是缺点就是,通过指针获取值时需要使用 unsafe 代码。
当然,上面的代码你还能通过裸指针来修改 String,但是需要将 *const 修改为 *mut:
#[derive(Debug)] struct SelfRef { value: String, pointer_to_value: *mut String, } impl SelfRef { fn new(txt: &str) -> Self { SelfRef { value: String::from(txt), pointer_to_value: std::ptr::null_mut(), } } fn init(&mut self) { let self_ref: *mut String = &mut self.value; self.pointer_to_value = self_ref; } fn value(&self) -> &str { &self.value } fn pointer_to_value(&self) -> &String { assert!(!self.pointer_to_value.is_null(), "Test::b called without Test::init being called first"); unsafe { &*(self.pointer_to_value) } } } fn main() { let mut t = SelfRef::new("hello"); t.init(); println!("{}, {:p}", t.value(), t.pointer_to_value()); t.value.push_str(", world"); unsafe { (&mut *t.pointer_to_value).push_str("!"); } println!("{}, {:p}", t.value(), t.pointer_to_value()); }
运行后输出:
hello, 0x16f3aec70
hello, world!, 0x16f3aec70
上面的 unsafe 虽然简单好用,但是它不太安全,是否还有其他选择?还真的有,那就是 Pin。
无法被移动的 Pin
Pin 在后续章节会深入讲解,目前你只需要知道它可以固定住一个值,防止该值在内存中被移动。
通过开头我们知道,自引用最麻烦的就是创建引用的同时,值的所有权会被转移,而通过 Pin 就可以很好的防止这一点:
use std::marker::PhantomPinned; use std::pin::Pin; use std::ptr::NonNull; // 下面是一个自引用数据结构体,因为 slice 字段是一个指针,指向了 data 字段 // 我们无法使用普通引用来实现,因为违背了 Rust 的编译规则 // 因此,这里我们使用了一个裸指针,通过 NonNull 来确保它不会为 null struct Unmovable { data: String, slice: NonNull<String>, _pin: PhantomPinned, } impl Unmovable { // 为了确保函数返回时数据的所有权不会被转移,我们将它放在堆上,唯一的访问方式就是通过指针 fn new(data: String) -> Pin<Box<Self>> { let res = Unmovable { data, // 只有在数据到位时,才创建指针,否则数据会在开始之前就被转移所有权 slice: NonNull::dangling(), _pin: PhantomPinned, }; let mut boxed = Box::pin(res); let slice = NonNull::from(&boxed.data); // 这里其实安全的,因为修改一个字段不会转移整个结构体的所有权 unsafe { let mut_ref: Pin<&mut Self> = Pin::as_mut(&mut boxed); Pin::get_unchecked_mut(mut_ref).slice = slice; } boxed } } fn main() { let unmoved = Unmovable::new("hello".to_string()); // 只要结构体没有被转移,那指针就应该指向正确的位置,而且我们可以随意移动指针 let mut still_unmoved = unmoved; assert_eq!(still_unmoved.slice, NonNull::from(&still_unmoved.data)); // 因为我们的类型没有实现 `Unpin` 特征,下面这段代码将无法编译 // let mut new_unmoved = Unmovable::new("world".to_string()); // std::mem::swap(&mut *still_unmoved, &mut *new_unmoved); }
上面的代码也非常清晰,虽然使用了 unsafe,其实更多的是无奈之举,跟之前的 unsafe 实现完全不可同日而语。
其实 Pin 在这里并没有魔法,它也并不是实现自引用类型的主要原因,最关键的还是里面的裸指针的使用,而 Pin 起到的作用就是确保我们的值不会被移走,否则指针就会指向一个错误的地址!
使用 ouroboros
对于自引用结构体,三方库也有支持的,其中一个就是 ouroboros,当然它也有自己的限制,我们后面会提到,先来看看该如何使用:
use ouroboros::self_referencing; #[self_referencing] struct SelfRef { value: String, #[borrows(value)] pointer_to_value: &'this str, } fn main(){ let v = SelfRefBuilder { value: "aaa".to_string(), pointer_to_value_builder: |value: &String| value, }.build(); // 借用value值 let s = v.borrow_value(); // 借用指针 let p = v.borrow_pointer_to_value(); // value值和指针指向的值相等 assert_eq!(s, *p); }
可以看到,ouroboros 使用起来并不复杂,就是需要你去按照它的方式创建结构体和引用类型:SelfRef 变成 SelfRefBuilder,引用字段从 pointer_to_value 变成 pointer_to_value_builder,并且连类型都变了。
在使用时,通过 borrow_value 来借用 value 的值,通过 borrow_pointer_to_value 来借用 pointer_to_value 这个指针。
看上去很美好对吧?但是你可以尝试着去修改 String 字符串的值试试,ouroboros 限制还是较多的,但是对于基本类型依然是支持的不错,以下例子来源于官方:
use ouroboros::self_referencing; #[self_referencing] struct MyStruct { int_data: i32, float_data: f32, #[borrows(int_data)] int_reference: &'this i32, #[borrows(mut float_data)] float_reference: &'this mut f32, } fn main() { let mut my_value = MyStructBuilder { int_data: 42, float_data: 3.14, int_reference_builder: |int_data: &i32| int_data, float_reference_builder: |float_data: &mut f32| float_data, }.build(); // Prints 42 println!("{:?}", my_value.borrow_int_data()); // Prints 3.14 println!("{:?}", my_value.borrow_float_reference()); // Sets the value of float_data to 84.0 my_value.with_mut(|fields| { **fields.float_reference = (**fields.int_reference as f32) * 2.0; }); // We can hold on to this reference... let int_ref = *my_value.borrow_int_reference(); println!("{:?}", *int_ref); // As long as the struct is still alive. drop(my_value); // This will cause an error! // println!("{:?}", *int_ref); }
总之,使用这个库前,强烈建议看一些官方的例子中支持什么样的类型和 API,如果能满足的你的需求,就果断使用它,如果不能满足,就继续往下看。
只能说,它确实帮助我们解决了问题,但是一个是破坏了原有的结构,另外就是并不是所有数据类型都支持:它需要目标值的内存地址不会改变,因此 Vec 动态数组就不适合,因为当内存空间不够时,Rust 会重新分配一块空间来存放该数组,这会导致内存地址的改变。
类似的库还有:
- rental, 这个库其实是最有名的,但是好像不再维护了,用倒是没问题
- owning-ref,将所有者和它的引用绑定到一个封装类型
这三个库,各有各的特点,也各有各的缺陷,建议大家需要时,一定要仔细调研,并且写 demo 进行测试,不可大意。
rental 虽然不怎么维护,但是可能依然是这三个里面最强大的,而且网上的用例也比较多,容易找到参考代码
Rc + RefCell 或 Arc + Mutex
类似于循环引用的解决方式,自引用也可以用这种组合来解决,但是会导致代码的类型标识到处都是,大大的影响了可读性。
终极大法
如果两个放在一起会报错,那就分开它们。对,终极大法就这么简单,当然思路上的简单不代表实现上的简单,最终结果就是导致代码复杂度的上升。
学习一本书:如何实现链表
最后,推荐一本专门将如何实现链表的书(真是富有 Rust 特色,链表都能复杂到出书了 o_o),Learn Rust by writing Entirely Too Many Linked Lists
总结
上面讲了这么多方法,但是我们依然无法正确的告诉你在某个场景应该使用哪个方法,这个需要你自己的判断,因为自引用实在是过于复杂。
我们能做的就是告诉你,有这些办法可以解决自引用问题,而这些办法每个都有自己适用的范围,需要你未来去深入的挖掘和发现。
偷偷说一句,就算是我,遇到自引用一样挺头疼,好在这种情况真的不常见,往往是实现特定的算法和数据结构时才需要,应用代码中几乎用不到。
多线程并发编程
安全和高效的处理并发是 Rust 语言的主要目标之一。随着现代处理器的核心数不断增加,并发和并行已经成为日常编程不可或缺的一部分,甚至于 Go 语言已经将并发简化到一个 go 关键字就可以。
可惜的是,在 Rust 中由于语言设计理念、安全、性能的多方面考虑,并没有采用 Go 语言大道至简的方式,而是选择了多线程与 async/await 相结合,优点是可控性更强、性能更高,缺点是复杂度并不低,当然这也是系统级语言的应有选择:使用复杂度换取可控性和性能。
不过,大家也不用担心,本书的目标就是降低 Rust 使用门槛,这个门槛自然也包括如何在 Rust 中进行异步并发编程,我们将从多线程以及 async/await 两个方面去深入浅出地讲解,首先,从本章的多线程开始。
在本章,我们将深入讲解并发和并行的区别以及如何使用多线程进行 Rust 并发编程,那么先来看看何为并行与并发。
并发和并行
并发是同一时间应对多件事情的能力 - Rob Pike
并行和并发其实并不难,但是也给一些用户造成了困扰,因此我们专门开辟一个章节,用于讲清楚这两者的区别。
Erlang 之父 Joe Armstrong(伟大的异步编程先驱,开创一个时代的殿堂级计算机科学家,我还犹记得当年刚学到 Erlang 时的震撼,respect!)用一张 5 岁小孩都能看懂的图片解释了并发与并行的区别:
上图很直观的体现了:
- 并发(Concurrent) 是多个队列使用同一个咖啡机,然后两个队列轮换着使用(未必是 1:1 轮换,也可能是其它轮换规则),最终每个人都能接到咖啡
- 并行(Parallel) 是每个队列都拥有一个咖啡机,最终也是每个人都能接到咖啡,但是效率更高,因为同时可以有两个人在接咖啡
当然,我们还可以对比下串行:只有一个队列且仅使用一台咖啡机,前面哪个人接咖啡时突然发呆了几分钟,后面的人就只能等他结束才能继续接。可能有读者有疑问了,从图片来看,并发也存在这个问题啊,前面的人发呆了几分钟不接咖啡怎么办?很简单,另外一个队列的人把他推开就行了,自己队友不能在背后开枪,但是其它队的可以:)
在正式开始之前,先给出一个结论:并发和并行都是对“多任务”处理的描述,其中并发是轮流处理,而并行是同时处理。
CPU 多核
现在的个人计算机动辄拥有十来个核心(M1 Max/Intel 12 代),如果使用串行的方式那真是太低效了,因此我们把各种任务简单分成多个队列,每个队列都交给一个 CPU 核心去执行,当某个 CPU 核心没有任务时,它还能去其它核心的队列中偷任务(真·老黄牛),这样就实现了并行化处理。
单核心并发
那问题来了,在早期只有一个 CPU 核心时,我们的任务是怎么处理的呢?其实聪明的读者应该已经想到,是的,并发解君愁。当然,这里还得提到操作系统的多线程,正是操作系统多线程 + CPU 核心,才实现了现代化的多任务操作系统。
在 OS 级别,多线程负责管理我们的任务队列,你可以简单认为一个线程管理着一个任务队列,然后线程之间还能根据空闲度进行任务调度。我们的程序只会跟 OS 线程打交道,并不关心 CPU 到底有多少个核心,真正关心的只是 OS,当线程把任务交给 CPU 核心去执行时,如果只有一个 CPU 核心,那么它就只能同时处理一个任务。
相信大家都看出来了:CPU 核心对应的是上图的咖啡机,而多个线程的任务队列就对应的多个排队的队列,由于终受限于 CPU 核心数,每个队列每次只会有一个任务被处理。
和排队一样,假如某个任务执行时间过长,就会导致用户界面的假死(相信使用 Windows 的同学或多或少都碰到过假死的问题), 那么就需要 CPU 的任务调度了(真实 CPU 的调度很复杂,我们这里做了简化),有一个调度器会按照某些条件从队列中选择任务进行执行,并且当一个任务执行时间过长时,会强行切换该任务到后台中(或者放入任务队列,真实情况很复杂!),去执行新的任务。
不断这样的快速任务切换,对用户而言就实现了表面上的多任务同时处理,但是实际上最终也只有一个 CPU 核心在不停的工作。
因此并发的关键在于:快速轮换处理不同的任务,给用户带来所有任务同时在运行的假象。
多核心并行
当 CPU 核心增多到 N 时,那么同一时间就能有 N 个任务被处理,那么我们的并行度就是 N,相应的处理效率也变成了单核心的 N 倍(实际情况并没有这么高)。
多核心并发
当核心增多到 N 时,操作系统同时在进行的任务肯定远不止 N 个,这些任务将被放入 M 个线程队列中,接着交给 N 个 CPU 核心去执行,最后实现了 M:N 的处理模型,在这种情况下,并发与并行是同时在发生的,所有用户任务从表面来看都在并发的运行,但实际上,同一时刻只有 N 个任务能被同时并行的处理。
看到这里,相信大家已经明白两者的区别,那么我们下面给出一个正式的定义(该定义摘选自<<并发的艺术>>)。
正式的定义
如果某个系统支持两个或者多个动作的同时存在,那么这个系统就是一个并发系统。如果某个系统支持两个或者多个动作同时执行,那么这个系统就是一个并行系统。并发系统与并行系统这两个定义之间的关键差异在于 “存在” 这个词。
在并发程序中可以同时拥有两个或者多个线程。这意味着,如果程序在单核处理器上运行,那么这两个线程将交替地换入或者换出内存。这些线程是 同时“存在” 的——每个线程都处于执行过程中的某个状态。如果程序能够并行执行,那么就一定是运行在多核处理器上。此时,程序中的每个线程都将分配到一个独立的处理器核上,因此可以同时运行。
相信你已经能够得出结论——“并行”概念是“并发”概念的一个子集。也就是说,你可以编写一个拥有多个线程或者进程的并发程序,但如果没有多核处理器来执行这个程序,那么就不能以并行方式来运行代码。因此,凡是在求解单个问题时涉及多个执行流程的编程模式或者执行行为,都属于并发编程的范畴。
编程语言的并发模型
如果大家学过其它语言的多线程,可能就知道不同语言对于线程的实现可能大相径庭:
- 由于操作系统提供了创建线程的 API,因此部分语言会直接调用该 API 来创建线程,因此最终程序内的线程数和该程序占用的操作系统线程数相等,一般称之为1:1 线程模型,例如 Rust。
- 还有些语言在内部实现了自己的线程模型(绿色线程、协程),程序内部的 M 个线程最后会以某种映射方式使用 N 个操作系统线程去运行,因此称之为M:N 线程模型,其中 M 和 N 并没有特定的彼此限制关系。一个典型的代表就是 Go 语言。
- 还有些语言使用了 Actor 模型,基于消息传递进行并发,例如 Erlang 语言。
总之,每一种模型都有其优缺点及选择上的权衡,而 Rust 在设计时考虑的权衡就是运行时(Runtime)。出于 Rust 的系统级使用场景,且要保证调用 C 时的极致性能,它最终选择了尽量小的运行时实现。
运行时是那些会被打包到所有程序可执行文件中的 Rust 代码,根据每个语言的设计权衡,运行时虽然有大有小(例如 Go 语言由于实现了协程和 GC,运行时相对就会更大一些),但是除了汇编之外,每个语言都拥有它。小运行时的其中一个好处在于最终编译出的可执行文件会相对较小,同时也让该语言更容易被其它语言引入使用。
而绿色线程/协程的实现会显著增大运行时的大小,因此 Rust 只在标准库中提供了 1:1 的线程模型,如果你愿意牺牲一些性能来换取更精确的线程控制以及更小的线程上下文切换成本,那么可以选择 Rust 中的 M:N 模型,这些模型由三方库提供了实现,例如大名鼎鼎的 tokio。
在了解了并发和并行后,我们可以正式开始 Rust 的多线程之旅。
使用线程
放在十年前,多线程编程可能还是一个少数人才掌握的核心概念,但是在今天,随着编程语言的不断发展,多线程、多协程、Actor 等并发编程方式已经深入人心,同时多线程编程的门槛也在不断降低,本章节我们来看看在 Rust 中该如何使用多线程。
多线程编程的风险
由于多线程的代码是同时运行的,因此我们无法保证线程间的执行顺序,这会导致一些问题:
- 竞态条件(race conditions),多个线程以非一致性的顺序同时访问数据资源
- 死锁(deadlocks),两个线程都想使用某个资源,但是又都在等待对方释放资源后才能使用,结果最终都无法继续执行
- 一些因为多线程导致的很隐晦的 BUG,难以复现和解决
虽然 Rust 已经通过各种机制减少了上述情况的发生,但是依然无法完全避免上述情况,因此我们在编程时需要格外的小心,同时本书也会列出多线程编程时常见的陷阱,让你提前规避可能的风险。
创建线程
使用 thread::spawn 可以创建线程:
use std::thread; use std::time::Duration; fn main() { thread::spawn(|| { for i in 1..10 { println!("hi number {} from the spawned thread!", i); thread::sleep(Duration::from_millis(1)); } }); for i in 1..5 { println!("hi number {} from the main thread!", i); thread::sleep(Duration::from_millis(1)); } }
有几点值得注意:
- 线程内部的代码使用闭包来执行
main线程一旦结束,程序就立刻结束,因此需要保持它的存活,直到其它子线程完成自己的任务thread::sleep会让当前线程休眠指定的时间,随后其它线程会被调度运行(上一节并发与并行中有简单介绍过),因此就算你的电脑只有一个 CPU 核心,该程序也会表现的如同多 CPU 核心一般,这就是并发!
来看看输出:
hi number 1 from the main thread!
hi number 1 from the spawned thread!
hi number 2 from the main thread!
hi number 2 from the spawned thread!
hi number 3 from the main thread!
hi number 3 from the spawned thread!
hi number 4 from the spawned thread!
hi number 4 from the main thread!
hi number 5 from the spawned thread!
如果多运行几次,你会发现好像每次输出会不太一样,因为:虽说线程往往是轮流执行的,但是这一点无法被保证!线程调度的方式往往取决于你使用的操作系统。总之,千万不要依赖线程的执行顺序。
等待子线程的结束
上面的代码你不但可能无法让子线程从 1 顺序打印到 10,而且可能打印的数字会变少,因为主线程会提前结束,导致子线程也随之结束,更过分的是,如果当前系统繁忙,甚至该子线程还没被创建,主线程就已经结束了!
因此我们需要一个方法,让主线程安全、可靠地等所有子线程完成任务后,再 kill self:
use std::thread; use std::time::Duration; fn main() { let handle = thread::spawn(|| { for i in 1..5 { println!("hi number {} from the spawned thread!", i); thread::sleep(Duration::from_millis(1)); } }); handle.join().unwrap(); for i in 1..5 { println!("hi number {} from the main thread!", i); thread::sleep(Duration::from_millis(1)); } }
通过调用 handle.join,可以让当前线程阻塞,直到它等待的子线程的结束,在上面代码中,由于 main 线程会被阻塞,因此它直到子线程结束后才会输出自己的 1..5:
hi number 1 from the spawned thread!
hi number 2 from the spawned thread!
hi number 3 from the spawned thread!
hi number 4 from the spawned thread!
hi number 1 from the main thread!
hi number 2 from the main thread!
hi number 3 from the main thread!
hi number 4 from the main thread!
以上输出清晰的展示了线程阻塞的作用,如果你将 handle.join 放置在 main 线程中的 for 循环后面,那就是另外一个结果:两个线程交替输出。
在线程闭包中使用 move
在闭包章节中,有讲过 move 关键字在闭包中的使用可以让该闭包拿走环境中某个值的所有权,同样地,你可以使用 move 来将所有权从一个线程转移到另外一个线程。
首先,来看看在一个线程中直接使用另一个线程中的数据会如何:
use std::thread; fn main() { let v = vec![1, 2, 3]; let handle = thread::spawn(|| { println!("Here's a vector: {:?}", v); }); handle.join().unwrap(); }
以上代码在子线程的闭包中捕获了环境中的 v 变量,来看看结果:
error[E0373]: closure may outlive the current function, but it borrows `v`, which is owned by the current function
--> src/main.rs:6:32
|
6 | let handle = thread::spawn(|| {
| ^^ may outlive borrowed value `v`
7 | println!("Here's a vector: {:?}", v);
| - `v` is borrowed here
|
note: function requires argument type to outlive `'static`
--> src/main.rs:6:18
|
6 | let handle = thread::spawn(|| {
| __________________^
7 | | println!("Here's a vector: {:?}", v);
8 | | });
| |______^
help: to force the closure to take ownership of `v` (and any other referenced variables), use the `move` keyword
|
6 | let handle = thread::spawn(move || {
| ++++
其实代码本身并没有什么问题,问题在于 Rust 无法确定新的线程会活多久(多个线程的结束顺序并不是固定的),所以也无法确定新线程所引用的 v 是否在使用过程中一直合法:
use std::thread; fn main() { let v = vec![1, 2, 3]; let handle = thread::spawn(|| { println!("Here's a vector: {:?}", v); }); drop(v); // oh no! handle.join().unwrap(); }
大家要记住,线程的启动时间点和结束时间点是不确定的,因此存在一种可能,当主线程执行完, v 被释放掉时,新的线程很可能还没有结束甚至还没有被创建成功,此时新线程对 v 的引用立刻就不再合法!
好在报错里进行了提示:to force the closure to take ownership of v (and any other referenced variables), use the `move` keyword,让我们使用 move 关键字拿走 v 的所有权即可:
use std::thread; fn main() { let v = vec![1, 2, 3]; let handle = thread::spawn(move || { println!("Here's a vector: {:?}", v); }); handle.join().unwrap(); // 下面代码会报错borrow of moved value: `v` // println!("{:?}",v); }
如上所示,很简单的代码,而且 Rust 的所有权机制保证了数据使用上的安全:v 的所有权被转移给新的线程后,main 线程将无法继续使用:最后一行代码将报错。
线程是如何结束的
之前我们提到 main 线程是程序的主线程,一旦结束,则程序随之结束,同时各个子线程也将被强行终止。那么有一个问题,如果父线程不是 main 线程,那么父线程的结束会导致什么?自生自灭还是被干掉?
在系统编程中,操作系统提供了直接杀死线程的接口,简单粗暴,但是 Rust 并没有提供这样的接口,原因在于,粗暴地终止一个线程可能会导致资源没有释放、状态混乱等不可预期的结果,一向以安全自称的 Rust,自然不会砸自己的饭碗。
那么 Rust 中线程是如何结束的呢?答案很简单:线程的代码执行完,线程就会自动结束。但是如果线程中的代码不会执行完呢?那么情况可以分为两种进行讨论:
- 线程的任务是一个循环 IO 读取,任务流程类似:IO 阻塞,等待读取新的数据 -> 读到数据,处理完成 -> 继续阻塞等待 ··· -> 收到 socket 关闭的信号 -> 结束线程,在此过程中,绝大部分时间线程都处于阻塞的状态,因此虽然看上去是循环,CPU 占用其实很小,也是网络服务中最最常见的模型
- 线程的任务是一个循环,里面没有任何阻塞,包括休眠这种操作也没有,此时 CPU 很不幸的会被跑满,而且你如果没有设置终止条件,该线程将持续跑满一个 CPU 核心,并且不会被终止,直到
main线程的结束
第一情况很常见,我们来模拟看看第二种情况:
use std::thread; use std::time::Duration; fn main() { // 创建一个线程A let new_thread = thread::spawn(move || { // 再创建一个线程B thread::spawn(move || { loop { println!("I am a new thread."); } }) }); // 等待新创建的线程执行完成 new_thread.join().unwrap(); println!("Child thread is finish!"); // 睡眠一段时间,看子线程创建的子线程是否还在运行 thread::sleep(Duration::from_millis(100)); }
以上代码中,main 线程创建了一个新的线程 A,同时该新线程又创建了一个新的线程 B,可以看到 A 线程在创建完 B 线程后就立即结束了,而 B 线程则在不停地循环输出。
从之前的线程结束规则,我们可以猜测程序将这样执行:A 线程结束后,由它创建的 B 线程仍在疯狂输出,直到 main 线程在 100 毫秒后结束。如果你把该时间增加到几十秒,就可以看到你的 CPU 核心 100% 的盛况了-,-
多线程的性能
下面我们从多个方面来看看多线程的性能大概是怎么样的。
创建线程的性能
据不精确估算,创建一个线程大概需要 0.24 毫秒,随着线程的变多,这个值会变得更大,因此线程的创建耗时并不是不可忽略的,只有当真的需要处理一个值得用线程去处理的任务时,才使用线程,一些鸡毛蒜皮的任务,就无需创建线程了。
创建多少线程合适
因为 CPU 的核心数限制,当任务是 CPU 密集型时,就算线程数超过了 CPU 核心数,也并不能帮你获得更好的性能,因为每个线程的任务都可以轻松让 CPU 的某个核心跑满,既然如此,让线程数等于 CPU 核心数是最好的。
但是当你的任务大部分时间都处于阻塞状态时,就可以考虑增多线程数量,这样当某个线程处于阻塞状态时,会被切走,进而运行其它的线程,典型就是网络 IO 操作,我们可以为每一个进来的用户连接创建一个线程去处理,该连接绝大部分时间都是处于 IO 读取阻塞状态,因此有限的 CPU 核心完全可以处理成百上千的用户连接线程,但是事实上,对于这种网络 IO 情况,一般都不再使用多线程的方式了,毕竟操作系统的线程数是有限的,意味着并发数也很容易达到上限,而且过多的线程也会导致线程上下文切换的代价过大,使用 async/await 的 M:N 并发模型,就没有这个烦恼。
多线程的开销
下面的代码是一个无锁实现(CAS)的 Hashmap 在多线程下的使用:
#![allow(unused)] fn main() { for i in 0..num_threads { let ht = Arc::clone(&ht); let handle = thread::spawn(move || { for j in 0..adds_per_thread { let key = thread_rng().gen::<u32>(); let value = thread_rng().gen::<u32>(); ht.set_item(key, value); } }); handles.push(handle); } for handle in handles { handle.join().unwrap(); } }
按理来说,既然是无锁实现了,那么锁的开销应该几乎没有,性能会随着线程数的增加接近线性增长,但是真的是这样吗?
下图是该代码在 48 核机器上的运行结果:
从图上可以明显的看出:吞吐并不是线性增长,尤其从 16 核开始,甚至开始肉眼可见的下降,这是为什么呢?
限于书本的篇幅有限,我们只能给出大概的原因:
- 虽然是无锁,但是内部是 CAS 实现,大量线程的同时访问,会让 CAS 重试次数大幅增加
- 线程过多时,CPU 缓存的命中率会显著下降,同时多个线程竞争一个 CPU Cache-line 的情况也会经常发生
- 大量读写可能会让内存带宽也成为瓶颈
- 读和写不一样,无锁数据结构的读往往可以很好地线性增长,但是写不行,因为写竞争太大
总之,多线程的开销往往是在锁、数据竞争、缓存失效上,这些限制了现代化软件系统随着 CPU 核心的增多性能也线性增加的野心。
线程屏障(Barrier)
在 Rust 中,可以使用 Barrier 让多个线程都执行到某个点后,才继续一起往后执行:
use std::sync::{Arc, Barrier}; use std::thread; fn main() { let mut handles = Vec::with_capacity(6); let barrier = Arc::new(Barrier::new(6)); for _ in 0..6 { let b = barrier.clone(); handles.push(thread::spawn(move|| { println!("before wait"); b.wait(); println!("after wait"); })); } for handle in handles { handle.join().unwrap(); } }
上面代码,我们在线程打印出 before wait 后增加了一个屏障,目的就是等所有的线程都打印出before wait后,各个线程再继续执行:
before wait
before wait
before wait
before wait
before wait
before wait
after wait
after wait
after wait
after wait
after wait
after wait
线程局部变量(Thread Local Variable)
对于多线程编程,线程局部变量在一些场景下非常有用,而 Rust 通过标准库和三方库对此进行了支持。
标准库 thread_local
使用 thread_local 宏可以初始化线程局部变量,然后在线程内部使用该变量的 with 方法获取变量值:
#![allow(unused)] fn main() { use std::cell::RefCell; use std::thread; thread_local!(static FOO: RefCell<u32> = RefCell::new(1)); FOO.with(|f| { assert_eq!(*f.borrow(), 1); *f.borrow_mut() = 2; }); // 每个线程开始时都会拿到线程局部变量的FOO的初始值 let t = thread::spawn(move|| { FOO.with(|f| { assert_eq!(*f.borrow(), 1); *f.borrow_mut() = 3; }); }); // 等待线程完成 t.join().unwrap(); // 尽管子线程中修改为了3,我们在这里依然拥有main线程中的局部值:2 FOO.with(|f| { assert_eq!(*f.borrow(), 2); }); }
上面代码中,FOO 即是我们创建的线程局部变量,每个新的线程访问它时,都会使用它的初始值作为开始,各个线程中的 FOO 值彼此互不干扰。注意 FOO 使用 static 声明为生命周期为 'static 的静态变量。
可以注意到,线程中对 FOO 的使用是通过借用的方式,但是若我们需要每个线程独自获取它的拷贝,最后进行汇总,就有些强人所难了。
你还可以在结构体中使用线程局部变量:
use std::cell::RefCell; struct Foo; impl Foo { thread_local! { static FOO: RefCell<usize> = RefCell::new(0); } } fn main() { Foo::FOO.with(|x| println!("{:?}", x)); }
或者通过引用的方式使用它:
#![allow(unused)] fn main() { use std::cell::RefCell; use std::thread::LocalKey; thread_local! { static FOO: RefCell<usize> = RefCell::new(0); } struct Bar { foo: &'static LocalKey<RefCell<usize>>, } impl Bar { fn constructor() -> Self { Self { foo: &FOO, } } } }
三方库 thread-local
除了标准库外,一位大神还开发了 thread-local 库,它允许每个线程持有值的独立拷贝:
#![allow(unused)] fn main() { use thread_local::ThreadLocal; use std::sync::Arc; use std::cell::Cell; use std::thread; let tls = Arc::new(ThreadLocal::new()); // 创建多个线程 for _ in 0..5 { let tls2 = tls.clone(); thread::spawn(move || { // 将计数器加1 let cell = tls2.get_or(|| Cell::new(0)); cell.set(cell.get() + 1); }).join().unwrap(); } // 一旦所有子线程结束,收集它们的线程局部变量中的计数器值,然后进行求和 let tls = Arc::try_unwrap(tls).unwrap(); let total = tls.into_iter().fold(0, |x, y| x + y.get()); // 和为5 assert_eq!(total, 5); }
该库不仅仅使用了值的拷贝,而且还能自动把多个拷贝汇总到一个迭代器中,最后进行求和,非常好用。
用条件控制线程的挂起和执行
条件变量(Condition Variables)经常和 Mutex 一起使用,可以让线程挂起,直到某个条件发生后再继续执行:
use std::thread; use std::sync::{Arc, Mutex, Condvar}; fn main() { let pair = Arc::new((Mutex::new(false), Condvar::new())); let pair2 = pair.clone(); thread::spawn(move|| { let &(ref lock, ref cvar) = &*pair2; let mut started = lock.lock().unwrap(); println!("changing started"); *started = true; cvar.notify_one(); }); let &(ref lock, ref cvar) = &*pair; let mut started = lock.lock().unwrap(); while !*started { started = cvar.wait(started).unwrap(); } println!("started changed"); }
上述代码流程如下:
main线程首先进入while循环,调用wait方法挂起等待子线程的通知,并释放了锁started- 子线程获取到锁,并将其修改为
true,然后调用条件变量的notify_one方法来通知主线程继续执行
只被调用一次的函数
有时,我们会需要某个函数在多线程环境下只被调用一次,例如初始化全局变量,无论是哪个线程先调用函数来初始化,都会保证全局变量只会被初始化一次,随后的其它线程调用就会忽略该函数:
use std::thread; use std::sync::Once; static mut VAL: usize = 0; static INIT: Once = Once::new(); fn main() { let handle1 = thread::spawn(move || { INIT.call_once(|| { unsafe { VAL = 1; } }); }); let handle2 = thread::spawn(move || { INIT.call_once(|| { unsafe { VAL = 2; } }); }); handle1.join().unwrap(); handle2.join().unwrap(); println!("{}", unsafe { VAL }); }
代码运行的结果取决于哪个线程先调用 INIT.call_once (虽然代码具有先后顺序,但是线程的初始化顺序并无法被保证!因为线程初始化是异步的,且耗时较久),若 handle1 先,则输出 1,否则输出 2。
call_once 方法
执行初始化过程一次,并且只执行一次。
如果当前有另一个初始化过程正在运行,线程将阻止该方法被调用。
当这个函数返回时,保证一些初始化已经运行并完成,它还保证由执行的闭包所执行的任何内存写入都能被其他线程在这时可靠地观察到。
总结
Rust 的线程模型是 1:1 模型,因为 Rust 要保持尽量小的运行时。
我们可以使用 thread::spawn 来创建线程,创建出的多个线程之间并不存在执行顺序关系,因此代码逻辑千万不要依赖于线程间的执行顺序。
main 线程若是结束,则所有子线程都将被终止,如果希望等待子线程结束后,再结束 main 线程,你需要使用创建线程时返回的句柄的 join 方法。
在线程中无法直接借用外部环境中的变量值,因为新线程的启动时间点和结束时间点是不确定的,所以 Rust 无法保证该线程中借用的变量在使用过程中依然是合法的。你可以使用 move 关键字将变量的所有权转移给新的线程,来解决此问题。
父线程结束后,子线程仍在持续运行,直到子线程的代码运行完成或者 main 线程的结束。
线程间的消息传递
在多线程间有多种方式可以共享、传递数据,最常用的方式就是通过消息传递或者将锁和Arc联合使用,而对于前者,在编程界还有一个大名鼎鼎的Actor线程模型为其背书,典型的有 Erlang 语言,还有 Go 语言中很经典的一句话:
Do not communicate by sharing memory; instead, share memory by communicating
而对于后者,我们将在下一节中进行讲述。
消息通道
与 Go 语言内置的chan不同,Rust 是在标准库里提供了消息通道(channel),你可以将其想象成一场直播,多个主播联合起来在搞一场直播,最终内容通过通道传输给屏幕前的我们,其中主播被称之为发送者,观众被称之为接收者,显而易见的是:一个通道应该支持多个发送者和接收者。
但是,在实际使用中,我们需要使用不同的库来满足诸如:多发送者 -> 单接收者,多发送者 -> 多接收者等场景形式,此时一个标准库显然就不够了,不过别急,让我们先从标准库讲起。
多发送者,单接收者
标准库提供了通道std::sync::mpsc,其中mpsc是multiple producer, single consumer的缩写,代表了该通道支持多个发送者,但是只支持唯一的接收者。 当然,支持多个发送者也意味着支持单个发送者,我们先来看看单发送者、单接收者的简单例子:
use std::sync::mpsc; use std::thread; fn main() { // 创建一个消息通道, 返回一个元组:(发送者,接收者) let (tx, rx) = mpsc::channel(); // 创建线程,并发送消息 thread::spawn(move || { // 发送一个数字1, send方法返回Result<T,E>,通过unwrap进行快速错误处理 tx.send(1).unwrap(); // 下面代码将报错,因为编译器自动推导出通道传递的值是i32类型,那么Option<i32>类型将产生不匹配错误 // tx.send(Some(1)).unwrap() }); // 在主线程中接收子线程发送的消息并输出 println!("receive {}", rx.recv().unwrap()); }
以上代码并不复杂,但仍有几点需要注意:
tx,rx对应发送者和接收者,它们的类型由编译器自动推导:tx.send(1)发送了整数,因此它们分别是mpsc::Sender<i32>和mpsc::Receiver<i32>类型,需要注意,由于内部是泛型实现,一旦类型被推导确定,该通道就只能传递对应类型的值, 例如此例中非i32类型的值将导致编译错误- 接收消息的操作
rx.recv()会阻塞当前线程,直到读取到值,或者通道被关闭 - 需要使用
move将tx的所有权转移到子线程的闭包中
在注释中提到send方法返回一个Result<T,E>,说明它有可能返回一个错误,例如接收者被drop导致了发送的值不会被任何人接收,此时继续发送毫无意义,因此返回一个错误最为合适,在代码中我们仅仅使用unwrap进行了快速处理,但在实际项目中你需要对错误进行进一步的处理。
同样的,对于recv方法来说,当发送者关闭时,它也会接收到一个错误,用于说明不会再有任何值被发送过来。
不阻塞的 try_recv 方法
除了上述recv方法,还可以使用try_recv尝试接收一次消息,该方法并不会阻塞线程,当通道中没有消息时,它会立刻返回一个错误:
use std::sync::mpsc; use std::thread; fn main() { let (tx, rx) = mpsc::channel(); thread::spawn(move || { tx.send(1).unwrap(); }); println!("receive {:?}", rx.try_recv()); }
由于子线程的创建需要时间,因此println!和try_recv方法会先执行,而此时子线程的消息还未被发出。try_recv会尝试立即读取一次消息,因为消息没有发出,此次读取最终会报错,且主线程运行结束(可悲的是,相对于主线程中的代码,子线程的创建速度实在是过慢,直到主线程结束,都无法完成子线程的初始化。。):
receive Err(Empty)
如上,try_recv返回了一个错误,错误内容是Empty,代表通道并没有消息。如果你尝试把println!复制一些行,就会发现一个有趣的输出:
···
receive Err(Empty)
receive Ok(1)
receive Err(Disconnected)
···
如上,当子线程创建成功且发送消息后,主线程会接收到Ok(1)的消息内容,紧接着子线程结束,发送者也随着被drop,此时接收者又会报错,但是这次错误原因有所不同:Disconnected代表发送者已经被关闭。
传输具有所有权的数据
使用通道来传输数据,一样要遵循 Rust 的所有权规则:
- 若值的类型实现了
Copy特征,则直接复制一份该值,然后传输过去,例如之前的i32类型 - 若值没有实现
Copy,则它的所有权会被转移给接收端,在发送端继续使用该值将报错
一起来看看第二种情况:
use std::sync::mpsc; use std::thread; fn main() { let (tx, rx) = mpsc::channel(); thread::spawn(move || { let s = String::from("我,飞走咯!"); tx.send(s).unwrap(); println!("val is {}", s); }); let received = rx.recv().unwrap(); println!("Got: {}", received); }
以上代码中,String底层的字符串是存储在堆上,并没有实现Copy特征,当它被发送后,会将所有权从发送端的s转移给接收端的received,之后s将无法被使用:
error[E0382]: borrow of moved value: `s`
--> src/main.rs:10:31
|
8 | let s = String::from("我,飞走咯!");
| - move occurs because `s` has type `String`, which does not implement the `Copy` trait // 所有权被转移,由于`String`没有实现`Copy`特征
9 | tx.send(s).unwrap();
| - value moved here // 所有权被转移走
10 | println!("val is {}", s);
| ^ value borrowed here after move // 所有权被转移后,依然对s进行了借用
各种细节不禁令人感叹:Rust 还是安全!假如没有所有权的保护,String字符串将被两个线程同时持有,任何一个线程对字符串内容的修改都会导致另外一个线程持有的字符串被改变,除非你故意这么设计,否则这就是不安全的隐患。
使用 for 进行循环接收
下面来看看如何连续接收通道中的值:
use std::sync::mpsc; use std::thread; use std::time::Duration; fn main() { let (tx, rx) = mpsc::channel(); thread::spawn(move || { let vals = vec![ String::from("hi"), String::from("from"), String::from("the"), String::from("thread"), ]; for val in vals { tx.send(val).unwrap(); thread::sleep(Duration::from_secs(1)); } }); for received in rx { println!("Got: {}", received); } }
在上面代码中,主线程和子线程是并发运行的,子线程在不停的发送消息 -> 休眠 1 秒,与此同时,主线程使用for循环阻塞的从rx迭代器中接收消息,当子线程运行完成时,发送者tx会随之被drop,此时for循环将被终止,最终main线程成功结束。
使用多发送者
由于子线程会拿走发送者的所有权,因此我们必须对发送者进行克隆,然后让每个线程拿走它的一份拷贝:
use std::sync::mpsc; use std::thread; fn main() { let (tx, rx) = mpsc::channel(); let tx1 = tx.clone(); thread::spawn(move || { tx.send(String::from("hi from raw tx")).unwrap(); }); thread::spawn(move || { tx1.send(String::from("hi from cloned tx")).unwrap(); }); for received in rx { println!("Got: {}", received); } }
代码并无太大区别,就多了一个对发送者的克隆let tx1 = tx.clone();,然后一个子线程拿走tx的所有权,另一个子线程拿走tx1的所有权,皆大欢喜。
但是有几点需要注意:
- 需要所有的发送者都被
drop掉后,接收者rx才会收到错误,进而跳出for循环,最终结束主线程 - 这里虽然用了
clone但是并不会影响性能,因为它并不在热点代码路径中,仅仅会被执行一次 - 由于两个子线程谁先创建完成是未知的,因此哪条消息先发送也是未知的,最终主线程的输出顺序也不确定
消息顺序
上述第三点的消息顺序仅仅是因为线程创建引起的,并不代表通道中的消息是无序的,对于通道而言,消息的发送顺序和接收顺序是一致的,满足FIFO原则(先进先出)。
由于篇幅有限,具体的代码这里就不再给出,感兴趣的读者可以自己验证下。
同步和异步通道
Rust 标准库的mpsc通道其实分为两种类型:同步和异步。
异步通道
之前我们使用的都是异步通道:无论接收者是否正在接收消息,消息发送者在发送消息时都不会阻塞:
use std::sync::mpsc; use std::thread; use std::time::Duration; fn main() { let (tx, rx)= mpsc::channel(); let handle = thread::spawn(move || { println!("发送之前"); tx.send(1).unwrap(); println!("发送之后"); }); println!("睡眠之前"); thread::sleep(Duration::from_secs(3)); println!("睡眠之后"); println!("receive {}", rx.recv().unwrap()); handle.join().unwrap(); }
运行后输出如下:
睡眠之前
发送之前
发送之后
//···睡眠3秒
睡眠之后
receive 1
主线程因为睡眠阻塞了 3 秒,因此并没有进行消息接收,而子线程却在此期间轻松完成了消息的发送。等主线程睡眠结束后,才姗姗来迟的从通道中接收了子线程老早之前发送的消息。
从输出还可以看出,发送之前和发送之后是连续输出的,没有受到接收端主线程的任何影响,因此通过mpsc::channel创建的通道是异步通道。
同步通道
与异步通道相反,同步通道发送消息是阻塞的,只有在消息被接收后才解除阻塞,例如:
use std::sync::mpsc; use std::thread; use std::time::Duration; fn main() { let (tx, rx)= mpsc::sync_channel(0); let handle = thread::spawn(move || { println!("发送之前"); tx.send(1).unwrap(); println!("发送之后"); }); println!("睡眠之前"); thread::sleep(Duration::from_secs(3)); println!("睡眠之后"); println!("receive {}", rx.recv().unwrap()); handle.join().unwrap(); }
运行后输出如下:
睡眠之前
发送之前
//···睡眠3秒
睡眠之后
receive 1
发送之后
可以看出,主线程由于睡眠被阻塞导致无法接收消息,因此子线程的发送也一直被阻塞,直到主线程结束睡眠并成功接收消息后,发送才成功:发送之后的输出是在receive 1之后,说明只有接收消息彻底成功后,发送消息才算完成。
消息缓存
细心的读者可能已经发现在创建同步通道时,我们传递了一个参数0: mpsc::sync_channel(0);,这是什么意思呢?
答案不急给出,先将0改成1,然后再运行试试:
睡眠之前
发送之前
发送之后
睡眠之后
receive 1
纳尼。。竟然得到了和异步通道一样的效果:根本没有等待主线程的接收开始,消息发送就立即完成了! 难道同步通道变成了异步通道? 别急,将子线程中的代码修改下试试:
#![allow(unused)] fn main() { println!("首次发送之前"); tx.send(1).unwrap(); println!("首次发送之后"); tx.send(1).unwrap(); println!("再次发送之后"); }
在子线程中,我们又多发了一条消息,此时输出如下:
睡眠之前
首次发送之前
首次发送之后
//···睡眠3秒
睡眠之后
receive 1
再次发送之后
Bingo,更奇怪的事出现了,第一条消息瞬间发送完成,没有阻塞,而发送第二条消息时却符合同步通道的特点:阻塞了,直到主线程接收后,才发送完成。
其实,一切的关键就在于1上,该值可以用来指定同步通道的消息缓存条数,当你设定为N时,发送者就可以无阻塞的往通道中发送N条消息,当消息缓冲队列满了后,新的消息发送将被阻塞(如果没有接收者消费缓冲队列中的消息,那么第N+1条消息就将触发发送阻塞)。
问题又来了,异步通道创建时完全没有这个缓冲值参数mpsc::channel(),它的缓冲值怎么设置呢? 额。。。都异步了,都可以无限发送了,都有摩托车了,还要自行车做啥子哦?事实上异步通道的缓冲上限取决于你的内存大小,不要撑爆就行。
因此,使用异步消息虽然能非常高效且不会造成发送线程的阻塞,但是存在消息未及时消费,最终内存过大的问题。在实际项目中,可以考虑使用一个带缓冲值的同步通道来避免这种风险。
关闭通道
之前我们数次提到了通道关闭,并且提到了当通道关闭后,发送消息或接收消息将会报错。那么如何关闭通道呢? 很简单:所有发送者被drop或者所有接收者被drop后,通道会自动关闭。
神奇的是,这件事是在编译期实现的,完全没有运行期性能损耗!只能说 Rust 的Drop特征 YYDS!
传输多种类型的数据
之前提到过,一个消息通道只能传输一种类型的数据,如果你想要传输多种类型的数据,可以为每个类型创建一个通道,你也可以使用枚举类型来实现:
use std::sync::mpsc::{self, Receiver, Sender}; enum Fruit { Apple(u8), Orange(String) } fn main() { let (tx, rx): (Sender<Fruit>, Receiver<Fruit>) = mpsc::channel(); tx.send(Fruit::Orange("sweet".to_string())).unwrap(); tx.send(Fruit::Apple(2)).unwrap(); for _ in 0..2 { match rx.recv().unwrap() { Fruit::Apple(count) => println!("received {} apples", count), Fruit::Orange(flavor) => println!("received {} oranges", flavor), } } }
如上所示,枚举类型还能让我们带上想要传输的数据,但是有一点需要注意,Rust 会按照枚举中占用内存最大的那个成员进行内存对齐,这意味着就算你传输的是枚举中占用内存最小的成员,它占用的内存依然和最大的成员相同, 因此会造成内存上的浪费。
新手容易遇到的坑
mpsc虽然相当简洁明了,但是在使用起来还是可能存在坑:
use std::sync::mpsc; fn main() { use std::thread; let (send, recv) = mpsc::channel(); let num_threads = 3; for i in 0..num_threads { let thread_send = send.clone(); thread::spawn(move || { thread_send.send(i).unwrap(); println!("thread {:?} finished", i); }); } // 在这里drop send... for x in recv { println!("Got: {}", x); } println!("finished iterating"); }
以上代码看起来非常正常,但是运行后主线程会一直阻塞,最后一行打印输出也不会被执行,原因在于: 子线程拿走的是复制后的send的所有权,这些拷贝会在子线程结束后被drop,因此无需担心,但是send本身却直到main函数的结束才会被drop。
之前提到,通道关闭的两个条件:发送者全部drop或接收者被drop,要结束for循环显然是要求发送者全部drop,但是由于send自身没有被drop,会导致该循环永远无法结束,最终主线程会一直阻塞。
解决办法很简单,drop掉send即可:在代码中的注释下面添加一行drop(send);。
mpmc 更好的性能
如果你需要 mpmc(多发送者,多接收者)或者需要更高的性能,可以考虑第三方库:
- crossbeam-channel, 老牌强库,功能较全,性能较强,之前是独立的库,但是后面合并到了
crossbeam主仓库中 - flume, 官方给出的性能数据某些场景要比 crossbeam 更好些
线程同步:锁、Condvar 和信号量
在多线程编程中,同步性极其的重要,当你需要同时访问一个资源、控制不同线程的执行次序时,都需要使用到同步性。
在 Rust 中有多种方式可以实现同步性。在上一节中讲到的消息传递就是同步性的一种实现方式,例如我们可以通过消息传递来控制不同线程间的执行次序。还可以使用共享内存来实现同步性,例如通过锁和原子操作等并发原语来实现多个线程同时且安全地去访问一个资源。
该如何选择
共享内存可以说是同步的灵魂,因为消息传递的底层实际上也是通过共享内存来实现,两者的区别如下:
- 共享内存相对消息传递能节省多次内存拷贝的成本
- 共享内存的实现简洁的多
- 共享内存的锁竞争更多
消息传递适用的场景很多,我们下面列出了几个主要的使用场景:
- 需要可靠和简单的(简单不等于简洁)实现时
- 需要模拟现实世界,例如用消息去通知某个目标执行相应的操作时
- 需要一个任务处理流水线(管道)时,等等
而使用共享内存(并发原语)的场景往往就比较简单粗暴:需要简洁的实现以及更高的性能时。
总之,消息传递类似一个单所有权的系统:一个值同时只能有一个所有者,如果另一个线程需要该值的所有权,需要将所有权通过消息传递进行转移。而共享内存类似于一个多所有权的系统:多个线程可以同时访问同一个值。
互斥锁 Mutex
既然是共享内存,那并发原语自然是重中之重,先来一起看看皇冠上的明珠: 互斥锁Mutex(mutual exclusion 的缩写)。
Mutex让多个线程并发的访问同一个值变成了排队访问:同一时间,只允许一个线程A访问该值,其它线程需要等待A访问完成后才能继续。
单线程中使用 Mutex
先来看看单线程中Mutex该如何使用:
use std::sync::Mutex; fn main() { // 使用`Mutex`结构体的关联函数创建新的互斥锁实例 let m = Mutex::new(5); { // 获取锁,然后deref为`m`的引用 // lock返回的是Result let mut num = m.lock().unwrap(); *num = 6; // 锁自动被drop } println!("m = {:?}", m); }
在注释中,已经大致描述了代码的功能,不过有一点需要注意:和Box类似,数据被Mutex所拥有,要访问内部的数据,需要使用方法m.lock()向m申请一个锁, 该方法会阻塞当前线程,直到获取到锁,因此当多个线程同时访问该数据时,只有一个线程能获取到锁,其它线程只能阻塞着等待,这样就保证了数据能被安全的修改!
m.lock()方法也有可能报错,例如当前正在持有锁的线程panic了。在这种情况下,其它线程不可能再获得锁,因此lock方法会返回一个错误。
这里你可能奇怪,m.lock明明返回一个锁,怎么就变成我们的num数值了?聪明的读者可能会想到智能指针,没错,因为Mutex<T>是一个智能指针,准确的说是m.lock()返回一个智能指针MutexGuard<T>:
- 它实现了
Deref特征,会被自动解引用后获得一个引用类型,该引用指向Mutex内部的数据 - 它还实现了
Drop特征,在超出作用域后,自动释放锁,以便其它线程能继续获取锁
正因为智能指针的使用,使得我们无需任何操作就能获取其中的数据。 如果释放锁,你需要做的仅仅是做好锁的作用域管理,例如上述代码的内部花括号使用,建议读者尝试下去掉内部的花括号,然后再次尝试获取第二个锁num1,看看会发生什么,友情提示:不会报错,但是主线程会永远阻塞,因为不幸发生了死锁。
use std::sync::Mutex; fn main() { let m = Mutex::new(5); let mut num = m.lock().unwrap(); *num = 6; // 锁还没有被 drop 就尝试申请下一个锁,导致主线程阻塞 // drop(num); // 手动 drop num ,可以让 num1 申请到下个锁 let mut num1 = m.lock().unwrap(); *num1 = 7; // drop(num1); // 手动 drop num1 ,观察打印结果的不同 println!("m = {:?}", m); }
多线程中使用 Mutex
单线程中使用锁,说实话纯粹是为了演示功能,毕竟多线程才是锁的舞台。 现在,我们再来看看,如何在多线程下使用Mutex来访问同一个资源.
无法运行的Rc<T>
use std::rc::Rc; use std::sync::Mutex; use std::thread; fn main() { // 通过`Rc`实现`Mutex`的多所有权 let counter = Rc::new(Mutex::new(0)); let mut handles = vec![]; for _ in 0..10 { let counter = Rc::clone(&counter); // 创建子线程,并将`Mutex`的所有权拷贝传入到子线程中 let handle = thread::spawn(move || { let mut num = counter.lock().unwrap(); *num += 1; }); handles.push(handle); } // 等待所有子线程完成 for handle in handles { handle.join().unwrap(); } // 输出最终的计数结果 println!("Result: {}", *counter.lock().unwrap()); }
由于子线程需要通过move拿走锁的所有权,因此我们需要使用多所有权来保证每个线程都拿到数据的独立所有权,恰好智能指针Rc<T>可以做到(上面代码会报错!具体往下看,别跳过-, -)。
以上代码实现了在多线程中计数的功能,由于多个线程都需要去修改该计数器,因此我们需要使用锁来保证同一时间只有一个线程可以修改计数器,否则会导致脏数据:想象一下 A 线程和 B 线程同时拿到计数器,获取了当前值1, 并且同时对其进行了修改,最后值变成2,你会不会在风中凌乱?毕竟正确的值是3,因为两个线程各自加 1。
可能有人会说,有那么巧的事情吗?事实上,对于人类来说,因为干啥啥慢,并没有那么多巧合,所以人总会存在巧合心理。但是对于计算机而言,每秒可以轻松运行上亿次,在这种频次下,一切巧合几乎都将必然发生,因此千万不要有任何侥幸心理。
如果事情有变坏的可能,不管这种可能性有多小,它都会发生! - 在计算机领域歪打正着的墨菲定律
事实上,上面的代码会报错:
error[E0277]: `Rc<Mutex<i32>>` cannot be sent between threads safely
// `Rc`无法在线程中安全的传输
--> src/main.rs:11:22
|
13 | let handle = thread::spawn(move || {
| ______________________^^^^^^^^^^^^^_-
| | |
| | `Rc<Mutex<i32>>` cannot be sent between threads safely
14 | | let mut num = counter.lock().unwrap();
15 | |
16 | | *num += 1;
17 | | });
| |_________- within this `[closure@src/main.rs:11:36: 15:10]`
|
= help: within `[closure@src/main.rs:11:36: 15:10]`, the trait `Send` is not implemented for `Rc<Mutex<i32>>`
// `Rc`没有实现`Send`特征
= note: required because it appears within the type `[closure@src/main.rs:11:36: 15:10]`
错误中提到了一个关键点:Rc<T>无法在线程中传输,因为它没有实现Send特征(在下一节将详细介绍),而该特征可以确保数据在线程中安全的传输。
多线程安全的 Arc<T>
好在,我们有Arc<T>,得益于它的内部计数器是多线程安全的,因此可以在多线程环境中使用:
use std::sync::{Arc, Mutex}; use std::thread; fn main() { let counter = Arc::new(Mutex::new(0)); let mut handles = vec![]; for _ in 0..10 { let counter = Arc::clone(&counter); let handle = thread::spawn(move || { let mut num = counter.lock().unwrap(); *num += 1; }); handles.push(handle); } for handle in handles { handle.join().unwrap(); } println!("Result: {}", *counter.lock().unwrap()); }
以上代码可以顺利运行:
Result: 10
内部可变性
在之前章节,我们提到过内部可变性,其中Rc<T>和RefCell<T>的结合,可以实现单线程的内部可变性。
现在我们又有了新的武器,由于Mutex<T>可以支持修改内部数据,当结合Arc<T>一起使用时,可以实现多线程的内部可变性。
简单总结下:Rc<T>/RefCell<T>用于单线程内部可变性, Arc<T>/Mutex<T>用于多线程内部可变性。
需要小心使用的 Mutex
如果有其它语言的编程经验,就知道互斥锁这家伙不好对付,想要正确使用,你得牢记在心:
- 在使用数据前必须先获取锁
- 在数据使用完成后,必须及时的释放锁,比如文章开头的例子,使用内部语句块的目的就是为了及时的释放锁
这两点看起来不起眼,但要正确的使用,其实是相当不简单的,对于其它语言,忘记释放锁是经常发生的,虽然 Rust 通过智能指针的drop机制帮助我们避免了这一点,但是由于不及时释放锁导致的性能问题也是常见的。
正因为这种困难性,导致很多用户都热衷于使用消息传递的方式来实现同步,例如 Go 语言直接把channel内置在语言特性中,甚至还有无锁的语言,例如erlang,完全使用Actor模型,依赖消息传递来完成共享和同步。幸好 Rust 的类型系统、所有权机制、智能指针等可以很好的帮助我们减轻使用锁时的负担。
另一个值的注意的是在使用Mutex<T>时,Rust 无法帮我们避免所有的逻辑错误,例如在之前章节,我们提到过使用Rc<T>可能会导致循环引用的问题。类似的,Mutex<T>也存在使用上的风险,例如创建死锁(deadlock):当一个操作试图锁住两个资源,然后两个线程各自获取其中一个锁,并试图获取另一个锁时,就会造成死锁。
死锁
在 Rust 中有多种方式可以创建死锁,了解这些方式有助于你提前规避可能的风险,一起来看看。
单线程死锁
这种死锁比较容易规避,但是当代码复杂后还是有可能遇到:
use std::sync::Mutex; fn main() { let data = Mutex::new(0); let d1 = data.lock(); let d2 = data.lock(); } // d1锁在此处释放
非常简单,只要你在另一个锁还未被释放时去申请新的锁,就会触发,当代码复杂后,这种情况可能就没有那么显眼。
多线程死锁
当我们拥有两个锁,且两个线程各自使用了其中一个锁,然后试图去访问另一个锁时,就可能发生死锁:
use std::{sync::{Mutex, MutexGuard}, thread}; use std::thread::sleep; use std::time::Duration; use lazy_static::lazy_static; lazy_static! { static ref MUTEX1: Mutex<i64> = Mutex::new(0); static ref MUTEX2: Mutex<i64> = Mutex::new(0); } fn main() { // 存放子线程的句柄 let mut children = vec![]; for i_thread in 0..2 { children.push(thread::spawn(move || { for _ in 0..1 { // 线程1 if i_thread % 2 == 0 { // 锁住MUTEX1 let guard: MutexGuard<i64> = MUTEX1.lock().unwrap(); println!("线程 {} 锁住了MUTEX1,接着准备去锁MUTEX2 !", i_thread); // 当前线程睡眠一小会儿,等待线程2锁住MUTEX2 sleep(Duration::from_millis(10)); // 去锁MUTEX2 let guard = MUTEX2.lock().unwrap(); // 线程2 } else { // 锁住MUTEX2 let _guard = MUTEX2.lock().unwrap(); println!("线程 {} 锁住了MUTEX2, 准备去锁MUTEX1", i_thread); let _guard = MUTEX1.lock().unwrap(); } } })); } // 等子线程完成 for child in children { let _ = child.join(); } println!("死锁没有发生"); }
在上面的描述中,我们用了"可能"二字,原因在于死锁在这段代码中不是必然发生的,总有一次运行你能看到最后一行打印输出。这是由于子线程的初始化顺序和执行速度并不确定,我们无法确定哪个线程中的锁先被执行,因此也无法确定两个线程对锁的具体使用顺序。
但是,可以简单的说明下死锁发生的必然条件:线程 1 锁住了MUTEX1并且线程2锁住了MUTEX2,然后线程 1 试图去访问MUTEX2,同时线程2试图去访问MUTEX1,就会死锁。 因为线程 2 需要等待线程 1 释放MUTEX1后,才会释放MUTEX2,而与此同时,线程 1 需要等待线程 2 释放MUTEX2后才能释放MUTEX1,这种情况造成了两个线程都无法释放对方需要的锁,最终死锁。
那么为何某些时候,死锁不会发生?原因很简单,线程 2 在线程 1 锁MUTEX1之前,就已经全部执行完了,随之线程 2 的MUTEX2和MUTEX1被全部释放,线程 1 对锁的获取将不再有竞争者。 同理,线程 1 若全部被执行完,那线程 2 也不会被锁,因此我们在线程 1 中间加一个睡眠,增加死锁发生的概率。如果你在线程 2 中同样的位置也增加一个睡眠,那死锁将必然发生!
try_lock
与lock方法不同,try_lock会尝试去获取一次锁,如果无法获取会返回一个错误,因此不会发生阻塞:
use std::{sync::{Mutex, MutexGuard}, thread}; use std::thread::sleep; use std::time::Duration; use lazy_static::lazy_static; lazy_static! { static ref MUTEX1: Mutex<i64> = Mutex::new(0); static ref MUTEX2: Mutex<i64> = Mutex::new(0); } fn main() { // 存放子线程的句柄 let mut children = vec![]; for i_thread in 0..2 { children.push(thread::spawn(move || { for _ in 0..1 { // 线程1 if i_thread % 2 == 0 { // 锁住MUTEX1 let guard: MutexGuard<i64> = MUTEX1.lock().unwrap(); println!("线程 {} 锁住了MUTEX1,接着准备去锁MUTEX2 !", i_thread); // 当前线程睡眠一小会儿,等待线程2锁住MUTEX2 sleep(Duration::from_millis(10)); // 去锁MUTEX2 let guard = MUTEX2.try_lock(); println!("线程1获取MUTEX2锁的结果: {:?}",guard); // 线程2 } else { // 锁住MUTEX2 let _guard = MUTEX2.lock().unwrap(); println!("线程 {} 锁住了MUTEX2, 准备去锁MUTEX1", i_thread); sleep(Duration::from_millis(10)); let guard = MUTEX1.try_lock(); println!("线程2获取MUTEX1锁的结果: {:?}",guard); } } })); } // 等子线程完成 for child in children { let _ = child.join(); } println!("死锁没有发生"); }
为了演示try_lock的作用,我们特定使用了之前必定会死锁的代码,并且将lock替换成try_lock,与之前的结果不同,这段代码将不会再有死锁发生:
线程 0 锁住了MUTEX1,接着准备去锁MUTEX2 !
线程 1 锁住了MUTEX2, 准备去锁MUTEX1
线程2获取MUTEX1锁的结果: Err("WouldBlock")
线程1获取MUTEX2锁的结果: Ok(0)
死锁没有发生
如上所示,当try_lock失败时,会报出一个错误:Err("WouldBlock"),接着线程中的剩余代码会继续执行,不会被阻塞。
一个有趣的命名规则:在 Rust 标准库中,使用
try_xxx都会尝试进行一次操作,如果无法完成,就立即返回,不会发生阻塞。例如消息传递章节中的try_recv以及本章节中的try_lock
读写锁 RwLock
Mutex会对每次读写都进行加锁,但某些时候,我们需要大量的并发读,Mutex就无法满足需求了,此时就可以使用RwLock:
use std::sync::RwLock; fn main() { let lock = RwLock::new(5); // 同一时间允许多个读 { let r1 = lock.read().unwrap(); let r2 = lock.read().unwrap(); assert_eq!(*r1, 5); assert_eq!(*r2, 5); } // 读锁在此处被drop // 同一时间只允许一个写 { let mut w = lock.write().unwrap(); *w += 1; assert_eq!(*w, 6); // 以下代码会panic,因为读和写不允许同时存在 // 写锁w直到该语句块结束才被释放,因此下面的读锁依然处于`w`的作用域中 // let r1 = lock.read(); // println!("{:?}",r1); }// 写锁在此处被drop }
RwLock在使用上和Mutex区别不大,需要注意的是,当读写同时发生时,程序会直接panic(本例是单线程,实际上多个线程中也是如此),因为会发生死锁:
thread 'main' panicked at 'rwlock read lock would result in deadlock', /rustc/efec545293b9263be9edfb283a7aa66350b3acbf/library/std/src/sys/unix/rwlock.rs:49:13
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
好在我们可以使用try_write和try_read来尝试进行一次写/读,若失败则返回错误:
Err("WouldBlock")
简单总结下RwLock:
- 同时允许多个读,但最多只能有一个写
- 读和写不能同时存在
- 读可以使用
read、try_read,写write、try_write, 在实际项目中,try_xxx会安全的多
Mutex 还是 RwLock
首先简单性上Mutex完胜,因为使用RwLock你得操心几个问题:
- 读和写不能同时发生,如果使用
try_xxx解决,就必须做大量的错误处理和失败重试机制 - 当读多写少时,写操作可能会因为一直无法获得锁导致连续多次失败(writer starvation)
- RwLock 其实是操作系统提供的,实现原理要比
Mutex复杂的多,因此单就锁的性能而言,比不上原生实现的Mutex
再来简单总结下两者的使用场景:
- 追求高并发读取时,使用
RwLock,因为Mutex一次只允许一个线程去读取 - 如果要保证写操作的成功性,使用
Mutex - 不知道哪个合适,统一使用
Mutex
需要注意的是,RwLock虽然看上去貌似提供了高并发读取的能力,但这个不能说明它的性能比Mutex高,事实上Mutex性能要好不少,后者唯一的问题也仅仅在于不能并发读取。
一个常见的、错误的使用RwLock的场景就是使用HashMap进行简单读写,因为HashMap的读和写都非常快,RwLock的复杂实现和相对低的性能反而会导致整体性能的降低,因此一般来说更适合使用Mutex。
总之,如果你要使用RwLock要确保满足以下两个条件:并发读,且需要对读到的资源进行"长时间"的操作,HashMap也许满足了并发读的需求,但是往往并不能满足后者:"长时间"的操作。
benchmark 永远是你在迷茫时最好的朋友!
三方库提供的锁实现
标准库在设计时总会存在取舍,因为往往性能并不是最好的,如果你追求性能,可以使用三方库提供的并发原语:
- parking_lot, 功能更完善、稳定,社区较为活跃,star 较多,更新较为活跃
- spin, 在多数场景中性能比
parking_lot高一点,最近没怎么更新
如果不是追求特别极致的性能,建议选择前者。
用条件变量(Condvar)控制线程的同步
Mutex用于解决资源安全访问的问题,但是我们还需要一个手段来解决资源访问顺序的问题。而 Rust 考虑到了这一点,为我们提供了条件变量(Condition Variables),它经常和Mutex一起使用,可以让线程挂起,直到某个条件发生后再继续执行,其实Condvar我们在之前的多线程章节就已经见到过,现在再来看一个不同的例子:
use std::sync::{Arc,Mutex,Condvar}; use std::thread::{spawn,sleep}; use std::time::Duration; fn main() { let flag = Arc::new(Mutex::new(false)); let cond = Arc::new(Condvar::new()); let cflag = flag.clone(); let ccond = cond.clone(); let hdl = spawn(move || { let mut m = { *cflag.lock().unwrap() }; let mut counter = 0; while counter < 3 { while !m { m = *ccond.wait(cflag.lock().unwrap()).unwrap(); } { m = false; *cflag.lock().unwrap() = false; } counter += 1; println!("inner counter: {}", counter); } }); let mut counter = 0; loop { sleep(Duration::from_millis(1000)); *flag.lock().unwrap() = true; counter += 1; if counter > 3 { break; } println!("outside counter: {}", counter); cond.notify_one(); } hdl.join().unwrap(); println!("{:?}", flag); }
例子中通过主线程来触发子线程实现交替打印输出:
outside counter: 1
inner counter: 1
outside counter: 2
inner counter: 2
outside counter: 3
inner counter: 3
Mutex { data: true, poisoned: false, .. }
信号量 Semaphore
在多线程中,另一个重要的概念就是信号量,使用它可以让我们精准的控制当前正在运行的任务最大数量。想象一下,当一个新游戏刚开服时(有些较火的老游戏也会,比如wow),往往会控制游戏内玩家的同时在线数,一旦超过某个临界值,就开始进行排队进服。而在实际使用中,也有很多时候,我们需要通过信号量来控制最大并发数,防止服务器资源被撑爆。
本来 Rust 在标准库中有提供一个信号量实现, 但是由于各种原因这个库现在已经不再推荐使用了,因此我们推荐使用tokio中提供的Semaphore实现: tokio::sync::Semaphore。
use std::sync::Arc; use tokio::sync::Semaphore; #[tokio::main] async fn main() { let semaphore = Arc::new(Semaphore::new(3)); let mut join_handles = Vec::new(); for _ in 0..5 { let permit = semaphore.clone().acquire_owned().await.unwrap(); join_handles.push(tokio::spawn(async move { // // 在这里执行任务... // drop(permit); })); } for handle in join_handles { handle.await.unwrap(); } }
上面代码创建了一个容量为 3 的信号量,当正在执行的任务超过 3 时,剩下的任务需要等待正在执行任务完成并减少信号量后到 3 以内时,才能继续执行。
这里的关键其实说白了就在于:信号量的申请和归还,使用前需要申请信号量,如果容量满了,就需要等待;使用后需要释放信号量,以便其它等待者可以继续。
总结
在很多时候,消息传递都是非常好用的手段,它可以让我们的数据在任务流水线上不断流转,实现起来非常优雅。
但是它并不能优雅的解决所有问题,因为我们面临的真实世界是非常复杂的,无法用某一种银弹统一解决。当面临消息传递不太适用的场景时,或者需要更好的性能和简洁性时,我们往往需要用锁来解决这些问题,因为锁允许多个线程同时访问同一个资源,简单粗暴。
除了锁之外,其实还有一种并发原语可以帮助我们解决并发访问数据的问题,那就是原子类型 Atomic,在下一章节中,我们会对其进行深入讲解。
线程同步:Atomic 原子类型与内存顺序
Mutex用起来简单,但是无法并发读,RwLock可以并发读,但是使用场景较为受限且性能不够,那么有没有一种全能性选手呢? 欢迎我们的Atomic闪亮登场。
从 Rust1.34 版本后,就正式支持原子类型。原子指的是一系列不可被 CPU 上下文交换的机器指令,这些指令组合在一起就形成了原子操作。在多核 CPU 下,当某个 CPU 核心开始运行原子操作时,会先暂停其它 CPU 内核对内存的操作,以保证原子操作不会被其它 CPU 内核所干扰。
由于原子操作是通过指令提供的支持,因此它的性能相比锁和消息传递会好很多。相比较于锁而言,原子类型不需要开发者处理加锁和释放锁的问题,同时支持修改,读取等操作,还具备较高的并发性能,几乎所有的语言都支持原子类型。
可以看出原子类型是无锁类型,但是无锁不代表无需等待,因为原子类型内部使用了CAS循环,当大量的冲突发生时,该等待还是得等待!但是总归比锁要好。
CAS 全称是 Compare and swap, 它通过一条指令读取指定的内存地址,然后判断其中的值是否等于给定的前置值,如果相等,则将其修改为新的值
使用 Atomic 作为全局变量
原子类型的一个常用场景,就是作为全局变量来使用:
use std::ops::Sub; use std::sync::atomic::{AtomicU64, Ordering}; use std::thread::{self, JoinHandle}; use std::time::Instant; const N_TIMES: u64 = 10000000; const N_THREADS: usize = 10; static R: AtomicU64 = AtomicU64::new(0); fn add_n_times(n: u64) -> JoinHandle<()> { thread::spawn(move || { for _ in 0..n { R.fetch_add(1, Ordering::Relaxed); } }) } fn main() { let s = Instant::now(); let mut threads = Vec::with_capacity(N_THREADS); for _ in 0..N_THREADS { threads.push(add_n_times(N_TIMES)); } for thread in threads { thread.join().unwrap(); } assert_eq!(N_TIMES * N_THREADS as u64, R.load(Ordering::Relaxed)); println!("{:?}",Instant::now().sub(s)); }
以上代码启动了数个线程,每个线程都在疯狂对全局变量进行加 1 操作, 最后将它与线程数 * 加1次数进行比较,如果发生了因为多个线程同时修改导致了脏数据,那么这两个必将不相等。好在,它没有让我们失望,不仅快速的完成了任务,而且保证了 100%的并发安全性。
当然以上代码的功能其实也可以通过Mutex来实现,但是后者的强大功能是建立在额外的性能损耗基础上的,因此性能会逊色不少:
Atomic实现:673ms
Mutex实现: 1136ms
可以看到Atomic实现会比Mutex快41%,实际上在复杂场景下还能更快(甚至达到 4 倍的性能差距)!
还有一点值得注意: 和Mutex一样,Atomic的值具有内部可变性,你无需将其声明为mut:
use std::sync::Mutex; use std::sync::atomic::{Ordering, AtomicU64}; struct Counter { count: u64 } fn main() { let n = Mutex::new(Counter { count: 0 }); n.lock().unwrap().count += 1; let n = AtomicU64::new(0); n.fetch_add(0, Ordering::Relaxed); }
这里有一个奇怪的枚举成员Ordering::Relaxed, 看上去很像是排序作用,但是我们并没有做排序操作啊?实际上它用于控制原子操作使用的内存顺序。
内存顺序
内存顺序是指 CPU 在访问内存时的顺序,该顺序可能受以下因素的影响:
- 代码中的先后顺序
- 编译器优化导致在编译阶段发生改变(内存重排序 reordering)
- 运行阶段因 CPU 的缓存机制导致顺序被打乱
编译器优化导致内存顺序的改变
对于第二点,我们举个例子:
static mut X: u64 = 0; static mut Y: u64 = 1; fn main() { ... // A unsafe { ... // B X = 1; ... // C Y = 3; ... // D X = 2; ... // E } }
假如在C和D代码片段中,根本没有用到X = 1,那么编译器很可能会将X = 1和X = 2进行合并:
#![allow(unused)] fn main() { ... // A unsafe { ... // B X = 2; ... // C Y = 3; ... // D ... // E } }
若代码A中创建了一个新的线程用于读取全局静态变量X,则该线程将无法读取到X = 1的结果,因为在编译阶段就已经被优化掉。
CPU 缓存导致的内存顺序的改变
假设之前的X = 1没有被优化掉,并且在代码片段A中有一个新的线程:
initial state: X = 0, Y = 1
THREAD Main THREAD A
X = 1; if X == 1 {
Y = 3; Y *= 2;
X = 2; }
我们来讨论下以上线程状态,Y最终的可能值(可能性依次降低):
Y = 3: 线程Main运行完后才运行线程A,或者线程A运行完后再运行线程MainY = 6: 线程Main的Y = 3运行完,但X = 2还没被运行, 此时线程 A 开始运行Y *= 2, 最后才运行Main线程的X = 2Y = 2: 线程Main正在运行Y = 3还没结束,此时线程A正在运行Y *= 2, 因此Y取到了值 1,然后Main的线程将Y设置为 3, 紧接着就被线程A的Y = 2所覆盖Y = 2: 上面的还只是一般的数据竞争,这里虽然产生了相同的结果2,但是背后的原理大相径庭: 线程Main运行完Y = 3,但是 CPU 缓存中的Y = 3还没有被同步到其它 CPU 缓存中,此时线程A中的Y *= 2就开始读取Y,结果读到了值1,最终计算出结果2
甚至更改成:
initial state: X = 0, Y = 1
THREAD Main THREAD A
X = 1; if X == 2 {
Y = 3; Y *= 2;
X = 2; }
还是可能出现Y = 2,因为Main线程中的X和Y被同步到其它 CPU 缓存中的顺序未必一致。
限定内存顺序的 5 个规则
在理解了内存顺序可能存在的改变后,你就可以明白为什么 Rust 提供了Ordering::Relaxed用于限定内存顺序了,事实上,该枚举有 5 个成员:
- Relaxed, 这是最宽松的规则,它对编译器和 CPU 不做任何限制,可以乱序
- Release 释放,设定内存屏障(Memory barrier),保证它之前的操作永远在它之前,但是它后面的操作可能被重排到它前面
- Acquire 获取, 设定内存屏障,保证在它之后的访问永远在它之后,但是它之前的操作却有可能被重排到它后面,往往和
Release在不同线程中联合使用 - AcqRel, 是 Acquire 和 Release 的结合,同时拥有它们俩提供的保证。比如你要对一个
atomic自增 1,同时希望该操作之前和之后的读取或写入操作不会被重新排序 - SeqCst 顺序一致性,
SeqCst就像是AcqRel的加强版,它不管原子操作是属于读取还是写入的操作,只要某个线程有用到SeqCst的原子操作,线程中该SeqCst操作前的数据操作绝对不会被重新排在该SeqCst操作之后,且该SeqCst操作后的数据操作也绝对不会被重新排在SeqCst操作前。
这些规则由于是系统提供的,因此其它语言提供的相应规则也大同小异,大家如果不明白可以看看其它语言的相关解释。
内存屏障的例子
下面我们以Release和Acquire为例,使用它们构筑出一对内存屏障,防止编译器和 CPU 将屏障前(Release)和屏障后(Acquire)中的数据操作重新排在屏障围成的范围之外:
use std::thread::{self, JoinHandle}; use std::sync::atomic::{Ordering, AtomicBool}; static mut DATA: u64 = 0; static READY: AtomicBool = AtomicBool::new(false); fn reset() { unsafe { DATA = 0; } READY.store(false, Ordering::Relaxed); } fn producer() -> JoinHandle<()> { thread::spawn(move || { unsafe { DATA = 100; // A } READY.store(true, Ordering::Release); // B: 内存屏障 ↑ }) } fn consumer() -> JoinHandle<()> { thread::spawn(move || { while !READY.load(Ordering::Acquire) {} // C: 内存屏障 ↓ assert_eq!(100, unsafe { DATA }); // D }) } fn main() { loop { reset(); let t_producer = producer(); let t_consumer = consumer(); t_producer.join().unwrap(); t_consumer.join().unwrap(); } }
原则上,Acquire用于读取,而Release用于写入。但是由于有些原子操作同时拥有读取和写入的功能,此时就需要使用AcqRel来设置内存顺序了。在内存屏障中被写入的数据,都可以被其它线程读取到,不会有 CPU 缓存的问题。
内存顺序的选择
- 不知道怎么选择时,优先使用
SeqCst,虽然会稍微减慢速度,但是慢一点也比出现错误好 - 多线程只计数
fetch_add而不使用该值触发其他逻辑分支的简单使用场景,可以使用Relaxed
参考 Which std::sync::atomic::Ordering to use?
多线程中使用 Atomic
在多线程环境中要使用Atomic需要配合Arc:
use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use std::{hint, thread}; fn main() { let spinlock = Arc::new(AtomicUsize::new(1)); let spinlock_clone = Arc::clone(&spinlock); let thread = thread::spawn(move|| { spinlock_clone.store(0, Ordering::SeqCst); }); // 等待其它线程释放锁 while spinlock.load(Ordering::SeqCst) != 0 { hint::spin_loop(); } if let Err(panic) = thread.join() { println!("Thread had an error: {:?}", panic); } }
Atomic 能替代锁吗
那么原子类型既然这么全能,它可以替代锁吗?答案是不行:
- 对于复杂的场景下,锁的使用简单粗暴,不容易有坑
std::sync::atomic包中仅提供了数值类型的原子操作:AtomicBool,AtomicIsize,AtomicUsize,AtomicI8,AtomicU16等,而锁可以应用于各种类型- 在有些情况下,必须使用锁来配合,例如上一章节中使用
Mutex配合Condvar
Atomic 的应用场景
事实上,Atomic虽然对于用户不太常用,但是对于高性能库的开发者、标准库开发者都非常常用,它是并发原语的基石,除此之外,还有一些场景适用:
- 无锁(lock free)数据结构
- 全局变量,例如全局自增 ID, 在后续章节会介绍
- 跨线程计数器,例如可以用于统计指标
以上列出的只是Atomic适用的部分场景,具体场景需要大家未来根据自己的需求进行权衡选择。
基于 Send 和 Sync 的线程安全
为何 Rc、RefCell 和裸指针不可以在多线程间使用?如何让裸指针可以在多线程使用?我们一起来探寻下这些问题的答案。
无法用于多线程的Rc
先来看一段多线程使用Rc的代码:
use std::thread; use std::rc::Rc; fn main() { let v = Rc::new(5); let t = thread::spawn(move || { println!("{}",v); }); t.join().unwrap(); }
以上代码将v的所有权通过move转移到子线程中,看似正确实则会报错:
error[E0277]: `Rc<i32>` cannot be sent between threads safely
------ 省略部分报错 --------
= help: within `[closure@src/main.rs:5:27: 7:6]`, the trait `Send` is not implemented for `Rc<i32>`
表面原因是Rc无法在线程间安全的转移,实际是编译器给予我们的那句帮助: the trait `Send` is not implemented for `Rc<i32>` (Rc<i32>未实现Send特征), 那么此处的Send特征又是何方神圣?
Rc 和 Arc 源码对比
在介绍Send特征之前,再来看看Arc为何可以在多线程使用,玄机在于两者的源码实现上:
#![allow(unused)] fn main() { // Rc源码片段 impl<T: ?Sized> !marker::Send for Rc<T> {} impl<T: ?Sized> !marker::Sync for Rc<T> {} // Arc源码片段 unsafe impl<T: ?Sized + Sync + Send> Send for Arc<T> {} unsafe impl<T: ?Sized + Sync + Send> Sync for Arc<T> {} }
!代表移除特征的相应实现,上面代码中Rc<T>的Send和Sync特征被特地移除了实现,而Arc<T>则相反,实现了Sync + Send,再结合之前的编译器报错,大概可以明白了:Send和Sync是在线程间安全使用一个值的关键。
Send 和 Sync
Send和Sync是 Rust 安全并发的重中之重,但是实际上它们只是标记特征(marker trait,该特征未定义任何行为,因此非常适合用于标记), 来看看它们的作用:
- 实现
Send的类型可以在线程间安全的传递其所有权 - 实现
Sync的类型可以在线程间安全的共享(通过引用)
这里还有一个潜在的依赖:一个类型要在线程间安全的共享的前提是,指向它的引用必须能在线程间传递。因为如果引用都不能被传递,我们就无法在多个线程间使用引用去访问同一个数据了。
由上可知,若类型 T 的引用&T是Send,则T是Sync。
没有例子的概念讲解都是耍流氓,来看看RwLock的实现:
#![allow(unused)] fn main() { unsafe impl<T: ?Sized + Send + Sync> Sync for RwLock<T> {} }
首先RwLock可以在线程间安全的共享,那它肯定是实现了Sync,但是我们的关注点不在这里。众所周知,RwLock可以并发的读,说明其中的值T必定也可以在线程间共享,那T必定要实现Sync。
果不其然,上述代码中,T的特征约束中就有一个Sync特征,那问题又来了,Mutex是不是相反?再来看看:
#![allow(unused)] fn main() { unsafe impl<T: ?Sized + Send> Sync for Mutex<T> {} }
不出所料,Mutex<T>中的T并没有Sync特征约束。
武学秘籍再好,不见生死也是花拳绣腿。同样的,我们需要通过实战来彻底掌握Send和Sync,但在实战之前,先来简单看看有哪些类型实现了它们。
实现Send和Sync的类型
在 Rust 中,几乎所有类型都默认实现了Send和Sync,而且由于这两个特征都是可自动派生的特征(通过derive派生),意味着一个复合类型(例如结构体), 只要它内部的所有成员都实现了Send或者Sync,那么它就自动实现了Send或Sync。
正是因为以上规则,Rust 中绝大多数类型都实现了Send和Sync,除了以下几个(事实上不止这几个,只不过它们比较常见):
- 裸指针两者都没实现,因为它本身就没有任何安全保证
UnsafeCell不是Sync,因此Cell和RefCell也不是Rc两者都没实现(因为内部的引用计数器不是线程安全的)
当然,如果是自定义的复合类型,那没实现那哥俩的就较为常见了:只要复合类型中有一个成员不是Send或Sync,那么该复合类型也就不是Send或Sync。
手动实现 Send 和 Sync 是不安全的,通常并不需要手动实现 Send 和 Sync trait,实现者需要使用unsafe小心维护并发安全保证。
至此,相关的概念大家已经掌握,但是我敢肯定,对于这两个滑不溜秋的家伙,大家依然会非常模糊,不知道它们该如何使用。那么我们来一起看看如何让裸指针可以在线程间安全的使用。
为裸指针实现Send
上面我们提到裸指针既没实现Send,意味着下面代码会报错:
use std::thread; fn main() { let p = 5 as *mut u8; let t = thread::spawn(move || { println!("{:?}",p); }); t.join().unwrap(); }
报错跟之前无二: `*mut u8` cannot be sent between threads safely, 但是有一个问题,我们无法为其直接实现Send特征,好在可以用newtype类型 :struct MyBox(*mut u8);。
还记得之前的规则吗:复合类型中有一个成员没实现Send,该复合类型就不是Send,因此我们需要手动为它实现:
use std::thread; #[derive(Debug)] struct MyBox(*mut u8); unsafe impl Send for MyBox {} fn main() { let p = MyBox(5 as *mut u8); let t = thread::spawn(move || { println!("{:?}",p); }); t.join().unwrap(); }
此时,我们的指针已经可以欢快的在多线程间撒欢,以上代码很简单,但有一点需要注意:Send和Sync是unsafe特征,实现时需要用unsafe代码块包裹。
为裸指针实现Sync
由于Sync是多线程间共享一个值,大家可能会想这么实现:
use std::thread; fn main() { let v = 5; let t = thread::spawn(|| { println!("{:?}",&v); }); t.join().unwrap(); }
关于这种用法,在多线程章节也提到过,线程如果直接去借用其它线程的变量,会报错:closure may outlive the current function,, 原因在于编译器无法确定主线程main和子线程t谁的生命周期更长,特别是当两个线程都是子线程时,没有任何人知道哪个子线程会先结束,包括编译器!
因此我们得配合Arc去使用:
use std::thread; use std::sync::Arc; use std::sync::Mutex; #[derive(Debug)] struct MyBox(*const u8); unsafe impl Send for MyBox {} fn main() { let b = &MyBox(5 as *const u8); let v = Arc::new(Mutex::new(b)); let t = thread::spawn(move || { let _v1 = v.lock().unwrap(); }); t.join().unwrap(); }
上面代码将智能指针v的所有权转移给新线程,同时v包含了一个引用类型b,当在新的线程中试图获取内部的引用时,会报错:
error[E0277]: `*const u8` cannot be shared between threads safely
--> src/main.rs:25:13
|
25 | let t = thread::spawn(move || {
| ^^^^^^^^^^^^^ `*const u8` cannot be shared between threads safely
|
= help: within `MyBox`, the trait `Sync` is not implemented for `*const u8`
因为我们访问的引用实际上还是对主线程中的数据的借用,转移进来的仅仅是外层的智能指针引用。要解决很简单,为MyBox实现Sync:
#![allow(unused)] fn main() { unsafe impl Sync for MyBox {} }
总结
通过上面的两个裸指针的例子,我们了解了如何实现Send和Sync,以及如何只实现Send而不实现Sync,简单总结下:
- 实现
Send的类型可以在线程间安全的传递其所有权, 实现Sync的类型可以在线程间安全的共享(通过引用) - 绝大部分类型都实现了
Send和Sync,常见的未实现的有:裸指针、Cell、RefCell、Rc等 - 可以为自定义类型实现
Send和Sync,但是需要unsafe代码块 - 可以为部分 Rust 中的类型实现
Send、Sync,但是需要使用newtype,例如文中的裸指针例子
实践应用:多线程Web服务器 todo
全局变量
在一些场景,我们可能需要全局变量来简化状态共享的代码,包括全局 ID,全局数据存储等等,下面一起来看看有哪些创建全局变量的方法。
首先,有一点可以肯定,全局变量的生命周期肯定是'static,但是不代表它需要用static来声明,例如常量、字符串字面值等无需使用static进行声明,原因是它们已经被打包到二进制可执行文件中。
下面我们从编译期初始化及运行期初始化两个类别来介绍下全局变量有哪些类型及该如何使用。
编译期初始化
我们大多数使用的全局变量都只需要在编译期初始化即可,例如静态配置、计数器、状态值等等。
静态常量
全局常量可以在程序任何一部分使用,当然,如果它是定义在某个模块中,你需要引入对应的模块才能使用。常量,顾名思义它是不可变的,很适合用作静态配置:
const MAX_ID: usize = usize::MAX / 2; fn main() { println!("用户ID允许的最大值是{}",MAX_ID); }
常量与普通变量的区别
- 关键字是
const而不是let - 定义常量必须指明类型(如 i32)不能省略
- 定义常量时变量的命名规则一般是全部大写
- 常量可以在任意作用域进行定义,其生命周期贯穿整个程序的生命周期。编译时编译器会尽可能将其内联到代码中,所以在不同地方对同一常量的引用并不能保证引用到相同的内存地址
- 常量的赋值只能是常量表达式/数学表达式,也就是说必须是在编译期就能计算出的值,如果需要在运行时才能得出结果的值比如函数,则不能赋值给常量表达式
- 对于变量出现重复的定义(绑定)会发生变量遮盖,后面定义的变量会遮住前面定义的变量,常量则不允许出现重复的定义
静态变量
静态变量允许声明一个全局的变量,常用于全局数据统计,例如我们希望用一个变量来统计程序当前的总请求数:
static mut REQUEST_RECV: usize = 0; fn main() { unsafe { REQUEST_RECV += 1; assert_eq!(REQUEST_RECV, 1); } }
Rust 要求必须使用unsafe语句块才能访问和修改static变量,因为这种使用方式往往并不安全,其实编译器是对的,当在多线程中同时去修改时,会不可避免的遇到脏数据。
只有在同一线程内或者不在乎数据的准确性时,才应该使用全局静态变量。
和常量相同,定义静态变量的时候必须赋值为在编译期就可以计算出的值(常量表达式/数学表达式),不能是运行时才能计算出的值(如函数)
静态变量和常量的区别
- 静态变量不会被内联,在整个程序中,静态变量只有一个实例,所有的引用都会指向同一个地址
- 存储在静态变量中的值必须要实现 Sync trait
原子类型
想要全局计数器、状态控制等功能,又想要线程安全的实现,原子类型是非常好的办法。
use std::sync::atomic::{AtomicUsize, Ordering}; static REQUEST_RECV: AtomicUsize = AtomicUsize::new(0); fn main() { for _ in 0..100 { REQUEST_RECV.fetch_add(1, Ordering::Relaxed); } println!("当前用户请求数{:?}",REQUEST_RECV); }
关于原子类型的讲解看这篇文章
示例:全局 ID 生成器
来看看如何使用上面的内容实现一个全局 ID 生成器:
#![allow(unused)] fn main() { use std::sync::atomic::{Ordering, AtomicUsize}; struct Factory{ factory_id: usize, } static GLOBAL_ID_COUNTER: AtomicUsize = AtomicUsize::new(0); const MAX_ID: usize = usize::MAX / 2; fn generate_id()->usize{ // 检查两次溢出,否则直接加一可能导致溢出 let current_val = GLOBAL_ID_COUNTER.load(Ordering::Relaxed); if current_val > MAX_ID{ panic!("Factory ids overflowed"); } let next_id = GLOBAL_ID_COUNTER.fetch_add(1, Ordering::Relaxed); if next_id > MAX_ID{ panic!("Factory ids overflowed"); } next_id } impl Factory{ fn new()->Self{ Self{ factory_id: generate_id() } } } }
运行期初始化
以上的静态初始化有一个致命的问题:无法用函数进行静态初始化,例如你如果想声明一个全局的Mutex锁:
use std::sync::Mutex; static NAMES: Mutex<String> = Mutex::new(String::from("Sunface, Jack, Allen")); fn main() { let v = NAMES.lock().unwrap(); println!("{}",v); }
运行后报错如下:
error[E0015]: calls in statics are limited to constant functions, tuple structs and tuple variants
--> src/main.rs:3:42
|
3 | static NAMES: Mutex<String> = Mutex::new(String::from("sunface"));
但你又必须在声明时就对NAMES进行初始化,此时就陷入了两难的境地。好在天无绝人之路,我们可以使用lazy_static包来解决这个问题。
lazy_static
lazy_static是社区提供的非常强大的宏,用于懒初始化静态变量,之前的静态变量都是在编译期初始化的,因此无法使用函数调用进行赋值,而lazy_static允许我们在运行期初始化静态变量!
use std::sync::Mutex; use lazy_static::lazy_static; lazy_static! { static ref NAMES: Mutex<String> = Mutex::new(String::from("Sunface, Jack, Allen")); } fn main() { let mut v = NAMES.lock().unwrap(); v.push_str(", Myth"); println!("{}",v); }
当然,使用lazy_static在每次访问静态变量时,会有轻微的性能损失,因为其内部实现用了一个底层的并发原语std::sync::Once,在每次访问该变量时,程序都会执行一次原子指令用于确认静态变量的初始化是否完成。
lazy_static宏,匹配的是static ref,所以定义的静态变量都是不可变引用
可能有读者会问,为何需要在运行期初始化一个静态变量,除了上面的全局锁,你会遇到最常见的场景就是:一个全局的动态配置,它在程序开始后,才加载数据进行初始化,最终可以让各个线程直接访问使用
再来看一个使用lazy_static实现全局缓存的例子:
use lazy_static::lazy_static; use std::collections::HashMap; lazy_static! { static ref HASHMAP: HashMap<u32, &'static str> = { let mut m = HashMap::new(); m.insert(0, "foo"); m.insert(1, "bar"); m.insert(2, "baz"); m }; } fn main() { // 首次访问`HASHMAP`的同时对其进行初始化 println!("The entry for `0` is \"{}\".", HASHMAP.get(&0).unwrap()); // 后续的访问仅仅获取值,再不会进行任何初始化操作 println!("The entry for `1` is \"{}\".", HASHMAP.get(&1).unwrap()); }
需要注意的是,lazy_static直到运行到main中的第一行代码时,才进行初始化,非常lazy static。
Box::leak
在Box智能指针章节中,我们提到了Box::leak可以用于全局变量,例如用作运行期初始化的全局动态配置,先来看看如果不使用lazy_static也不使用Box::leak,会发生什么:
#[derive(Debug)] struct Config { a: String, b: String, } static mut CONFIG: Option<&mut Config> = None; fn main() { unsafe { CONFIG = Some(&mut Config { a: "A".to_string(), b: "B".to_string(), }); println!("{:?}", CONFIG) } }
以上代码我们声明了一个全局动态配置CONFIG,并且其值初始化为None,然后在程序开始运行后,给它赋予相应的值,运行后报错:
error[E0716]: temporary value dropped while borrowed
--> src/main.rs:10:28
|
10 | CONFIG = Some(&mut Config {
| _________-__________________^
| |_________|
| ||
11 | || a: "A".to_string(),
12 | || b: "B".to_string(),
13 | || });
| || ^-- temporary value is freed at the end of this statement
| ||_________||
| |_________|assignment requires that borrow lasts for `'static`
| creates a temporary which is freed while still in use
可以看到,Rust 的借用和生命周期规则限制了我们做到这一点,因为试图将一个局部生命周期的变量赋值给全局生命周期的CONFIG,这明显是不安全的。
好在Rust为我们提供了Box::leak方法,它可以将一个变量从内存中泄漏(听上去怪怪的,竟然做主动内存泄漏),然后将其变为'static生命周期,最终该变量将和程序活得一样久,因此可以赋值给全局静态变量CONFIG。
#[derive(Debug)] struct Config { a: String, b: String } static mut CONFIG: Option<&mut Config> = None; fn main() { let c = Box::new(Config { a: "A".to_string(), b: "B".to_string(), }); unsafe { // 将`c`从内存中泄漏,变成`'static`生命周期 CONFIG = Some(Box::leak(c)); println!("{:?}", CONFIG); } }
从函数中返回全局变量
问题又来了,如果我们需要在运行期,从一个函数返回一个全局变量该如何做?例如:
#[derive(Debug)] struct Config { a: String, b: String, } static mut CONFIG: Option<&mut Config> = None; fn init() -> Option<&'static mut Config> { Some(&mut Config { a: "A".to_string(), b: "B".to_string(), }) } fn main() { unsafe { CONFIG = init(); println!("{:?}", CONFIG) } }
报错这里就不展示了,跟之前大同小异,还是生命周期引起的,那么该如何解决呢?依然可以用Box::leak:
#[derive(Debug)] struct Config { a: String, b: String, } static mut CONFIG: Option<&mut Config> = None; fn init() -> Option<&'static mut Config> { let c = Box::new(Config { a: "A".to_string(), b: "B".to_string(), }); Some(Box::leak(c)) } fn main() { unsafe { CONFIG = init(); println!("{:?}", CONFIG) } }
标准库中的 OnceCell
在 Rust 标准库中提供 lazy::OnceCell 和 lazy::SyncOnceCell 两种 Cell,前者用于单线程,后者用于多线程,它们用来存储堆上的信息,并且具有最多只能赋值一次的特性。 如实现一个多线程的日志组件 Logger:
#![feature(once_cell)] use std::{lazy::SyncOnceCell, thread}; fn main() { // 子线程中调用 let handle = thread::spawn(|| { let logger = Logger::global(); logger.log("thread message".to_string()); }); // 主线程调用 let logger = Logger::global(); logger.log("some message".to_string()); let logger2 = Logger::global(); logger2.log("other message".to_string()); handle.join().unwrap(); } #[derive(Debug)] struct Logger; static LOGGER: SyncOnceCell<Logger> = SyncOnceCell::new(); impl Logger { fn global() -> &'static Logger { // 获取或初始化 Logger LOGGER.get_or_init(|| { println!("Logger is being created..."); // 初始化打印 Logger }) } fn log(&self, message: String) { println!("{}", message) } }
以上代码我们声明了一个 global() 关联函数,并在其内部调用 get_or_init 进行初始化 Logger,之后在不同线程上多次调用 Logger::global() 获取其实例:
Logger is being created...
some message
other message
thread message
可以看到,Logger is being created... 在多个线程中使用也只被打印了一次。
特别注意,目前 OnceCell 和 SyncOnceCell API 暂未稳定,需启用特性 #![feature(once_cell)]。
总结
在 Rust 中有很多方式可以创建一个全局变量,本章也只是介绍了其中一部分,更多的还等待大家自己去挖掘学习(当然,未来可能本章节会不断完善,最后变成一个巨无霸- , -)。
简单来说,全局变量可以分为两种:
- 编译期初始化的全局变量,
const创建常量,static创建静态变量,Atomic创建原子类型 - 运行期初始化的全局变量,
lazy_static用于懒初始化,Box::leak利用内存泄漏将一个变量的生命周期变为'static
错误处理
在之前的返回值和错误处理章节中,我们学习了几个重要的概念,例如 Result 用于返回结果处理,? 用于错误的传播,若大家对此还较为模糊,强烈建议回头温习下。
在本章节中一起来看看如何对 Result ( Option ) 做进一步的处理,以及如何定义自己的错误类型。
组合器
在设计模式中,有一个组合器模式,相信有 Java 背景的同学对此并不陌生。
将对象组合成树形结构以表示“部分整体”的层次结构。组合模式使得用户对单个对象和组合对象的使用具有一致性。–GoF <<设计模式>>
与组合器模式有所不同,在 Rust 中,组合器更多的是用于对返回结果的类型进行变换:例如使用 ok_or 将一个 Option 类型转换成 Result 类型。
下面我们来看看一些常见的组合器。
or() 和 and()
跟布尔关系的与/或很像,这两个方法会对两个表达式做逻辑组合,最终返回 Option / Result。
or(),表达式按照顺序求值,若任何一个表达式的结果是Some或Ok,则该值会立刻返回and(),若两个表达式的结果都是Some或Ok,则第二个表达式中的值被返回。若任何一个的结果是None或Err,则立刻返回。
实际上,只要将布尔表达式的 true / false,替换成 Some / None 或 Ok / Err 就很好理解了。
fn main() { let s1 = Some("some1"); let s2 = Some("some2"); let n: Option<&str> = None; let o1: Result<&str, &str> = Ok("ok1"); let o2: Result<&str, &str> = Ok("ok2"); let e1: Result<&str, &str> = Err("error1"); let e2: Result<&str, &str> = Err("error2"); assert_eq!(s1.or(s2), s1); // Some1 or Some2 = Some1 assert_eq!(s1.or(n), s1); // Some or None = Some assert_eq!(n.or(s1), s1); // None or Some = Some assert_eq!(n.or(n), n); // None1 or None2 = None2 assert_eq!(o1.or(o2), o1); // Ok1 or Ok2 = Ok1 assert_eq!(o1.or(e1), o1); // Ok or Err = Ok assert_eq!(e1.or(o1), o1); // Err or Ok = Ok assert_eq!(e1.or(e2), e2); // Err1 or Err2 = Err2 assert_eq!(s1.and(s2), s2); // Some1 and Some2 = Some2 assert_eq!(s1.and(n), n); // Some and None = None assert_eq!(n.and(s1), n); // None and Some = None assert_eq!(n.and(n), n); // None1 and None2 = None1 assert_eq!(o1.and(o2), o2); // Ok1 and Ok2 = Ok2 assert_eq!(o1.and(e1), e1); // Ok and Err = Err assert_eq!(e1.and(o1), e1); // Err and Ok = Err assert_eq!(e1.and(e2), e1); // Err1 and Err2 = Err1 }
除了 or 和 and 之外,Rust 还为我们提供了 xor ,但是它只能应用在 Option 上,其实想想也是这个理,如果能应用在 Result 上,那你又该如何对一个值和错误进行异或操作?
or_else() 和 and_then()
它们跟 or() 和 and() 类似,唯一的区别在于,它们的第二个表达式是一个闭包。
fn main() { // or_else with Option let s1 = Some("some1"); let s2 = Some("some2"); let fn_some = || Some("some2"); // 类似于: let fn_some = || -> Option<&str> { Some("some2") }; let n: Option<&str> = None; let fn_none = || None; assert_eq!(s1.or_else(fn_some), s1); // Some1 or_else Some2 = Some1 assert_eq!(s1.or_else(fn_none), s1); // Some or_else None = Some assert_eq!(n.or_else(fn_some), s2); // None or_else Some = Some assert_eq!(n.or_else(fn_none), None); // None1 or_else None2 = None2 // or_else with Result let o1: Result<&str, &str> = Ok("ok1"); let o2: Result<&str, &str> = Ok("ok2"); let fn_ok = |_| Ok("ok2"); // 类似于: let fn_ok = |_| -> Result<&str, &str> { Ok("ok2") }; let e1: Result<&str, &str> = Err("error1"); let e2: Result<&str, &str> = Err("error2"); let fn_err = |_| Err("error2"); assert_eq!(o1.or_else(fn_ok), o1); // Ok1 or_else Ok2 = Ok1 assert_eq!(o1.or_else(fn_err), o1); // Ok or_else Err = Ok assert_eq!(e1.or_else(fn_ok), o2); // Err or_else Ok = Ok assert_eq!(e1.or_else(fn_err), e2); // Err1 or_else Err2 = Err2 }
fn main() { // and_then with Option let s1 = Some("some1"); let s2 = Some("some2"); let fn_some = |_| Some("some2"); // 类似于: let fn_some = |_| -> Option<&str> { Some("some2") }; let n: Option<&str> = None; let fn_none = |_| None; assert_eq!(s1.and_then(fn_some), s2); // Some1 and_then Some2 = Some2 assert_eq!(s1.and_then(fn_none), n); // Some and_then None = None assert_eq!(n.and_then(fn_some), n); // None and_then Some = None assert_eq!(n.and_then(fn_none), n); // None1 and_then None2 = None1 // and_then with Result let o1: Result<&str, &str> = Ok("ok1"); let o2: Result<&str, &str> = Ok("ok2"); let fn_ok = |_| Ok("ok2"); // 类似于: let fn_ok = |_| -> Result<&str, &str> { Ok("ok2") }; let e1: Result<&str, &str> = Err("error1"); let e2: Result<&str, &str> = Err("error2"); let fn_err = |_| Err("error2"); assert_eq!(o1.and_then(fn_ok), o2); // Ok1 and_then Ok2 = Ok2 assert_eq!(o1.and_then(fn_err), e2); // Ok and_then Err = Err assert_eq!(e1.and_then(fn_ok), e1); // Err and_then Ok = Err assert_eq!(e1.and_then(fn_err), e1); // Err1 and_then Err2 = Err1 }
filter
filter 用于对 Option 进行过滤:
fn main() { let s1 = Some(3); let s2 = Some(6); let n = None; let fn_is_even = |x: &i8| x % 2 == 0; assert_eq!(s1.filter(fn_is_even), n); // Some(3) -> 3 is not even -> None assert_eq!(s2.filter(fn_is_even), s2); // Some(6) -> 6 is even -> Some(6) assert_eq!(n.filter(fn_is_even), n); // None -> no value -> None }
map() 和 map_err()
map 可以将 Some 或 Ok 中的值映射为另一个:
fn main() { let s1 = Some("abcde"); let s2 = Some(5); let n1: Option<&str> = None; let n2: Option<usize> = None; let o1: Result<&str, &str> = Ok("abcde"); let o2: Result<usize, &str> = Ok(5); let e1: Result<&str, &str> = Err("abcde"); let e2: Result<usize, &str> = Err("abcde"); let fn_character_count = |s: &str| s.chars().count(); assert_eq!(s1.map(fn_character_count), s2); // Some1 map = Some2 assert_eq!(n1.map(fn_character_count), n2); // None1 map = None2 assert_eq!(o1.map(fn_character_count), o2); // Ok1 map = Ok2 assert_eq!(e1.map(fn_character_count), e2); // Err1 map = Err2 }
但是如果你想要将 Err 中的值进行改变, map 就无能为力了,此时我们需要用 map_err:
fn main() { let o1: Result<&str, &str> = Ok("abcde"); let o2: Result<&str, isize> = Ok("abcde"); let e1: Result<&str, &str> = Err("404"); let e2: Result<&str, isize> = Err(404); let fn_character_count = |s: &str| -> isize { s.parse().unwrap() }; // 该函数返回一个 isize assert_eq!(o1.map_err(fn_character_count), o2); // Ok1 map = Ok2 assert_eq!(e1.map_err(fn_character_count), e2); // Err1 map = Err2 }
通过对 o1 的操作可以看出,与 map 面对 Err 时的短小类似, map_err 面对 Ok 时也是相当无力的。
map_or() 和 map_or_else()
map_or 在 map 的基础上提供了一个默认值:
fn main() { const V_DEFAULT: u32 = 1; let s: Result<u32, ()> = Ok(10); let n: Option<u32> = None; let fn_closure = |v: u32| v + 2; assert_eq!(s.map_or(V_DEFAULT, fn_closure), 12); assert_eq!(n.map_or(V_DEFAULT, fn_closure), V_DEFAULT); }
如上所示,当处理 None 的时候,V_DEFAULT 作为默认值被直接返回。
map_or_else 与 map_or 类似,但是它是通过一个闭包来提供默认值:
fn main() { let s = Some(10); let n: Option<i8> = None; let fn_closure = |v: i8| v + 2; let fn_default = || 1; assert_eq!(s.map_or_else(fn_default, fn_closure), 12); assert_eq!(n.map_or_else(fn_default, fn_closure), 1); let o = Ok(10); let e = Err(5); let fn_default_for_result = |v: i8| v + 1; // 闭包可以对 Err 中的值进行处理,并返回一个新值 assert_eq!(o.map_or_else(fn_default_for_result, fn_closure), 12); assert_eq!(e.map_or_else(fn_default_for_result, fn_closure), 6); }
ok_or() and ok_or_else()
这两兄弟可以将 Option 类型转换为 Result 类型。其中 ok_or 接收一个默认的 Err 参数:
fn main() { const ERR_DEFAULT: &str = "error message"; let s = Some("abcde"); let n: Option<&str> = None; let o: Result<&str, &str> = Ok("abcde"); let e: Result<&str, &str> = Err(ERR_DEFAULT); assert_eq!(s.ok_or(ERR_DEFAULT), o); // Some(T) -> Ok(T) assert_eq!(n.ok_or(ERR_DEFAULT), e); // None -> Err(default) }
而 ok_or_else 接收一个闭包作为 Err 参数:
fn main() { let s = Some("abcde"); let n: Option<&str> = None; let fn_err_message = || "error message"; let o: Result<&str, &str> = Ok("abcde"); let e: Result<&str, &str> = Err("error message"); assert_eq!(s.ok_or_else(fn_err_message), o); // Some(T) -> Ok(T) assert_eq!(n.ok_or_else(fn_err_message), e); // None -> Err(default) }
以上列出的只是常用的一部分,强烈建议大家看看标准库中有哪些可用的 API,在实际项目中,这些 API 将会非常有用: Option 和 Result。
自定义错误类型
虽然标准库定义了大量的错误类型,但是一个严谨的项目,光使用这些错误类型往往是不够的,例如我们可能会为暴露给用户的错误定义相应的类型。
为了帮助我们更好的定义错误,Rust 在标准库中提供了一些可复用的特征,例如 std::error::Error 特征:
#![allow(unused)] fn main() { use std::fmt::{Debug, Display}; pub trait Error: Debug + Display { fn source(&self) -> Option<&(Error + 'static)> { ... } } }
当自定义类型实现该特征后,该类型就可以作为 Err 来使用,下面一起来看看。
实际上,自定义错误类型只需要实现
Debug和Display特征即可,source方法是可选的,而Debug特征往往也无需手动实现,可以直接通过derive来派生
最简单的错误
use std::fmt; // AppError 是自定义错误类型,它可以是当前包中定义的任何类型,在这里为了简化,我们使用了单元结构体作为例子。 // 为 AppError 自动派生 Debug 特征 #[derive(Debug)] struct AppError; // 为 AppError 实现 std::fmt::Display 特征 impl fmt::Display for AppError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "An Error Occurred, Please Try Again!") // user-facing output } } // 一个示例函数用于产生 AppError 错误 fn produce_error() -> Result<(), AppError> { Err(AppError) } fn main(){ match produce_error() { Err(e) => eprintln!("{}", e), _ => println!("No error"), } eprintln!("{:?}", produce_error()); // Err({ file: src/main.rs, line: 17 }) }
上面的例子很简单,我们定义了一个错误类型,当为它派生了 Debug 特征,同时手动实现了 Display 特征后,该错误类型就可以作为 Err来使用了。
事实上,实现 Debug 和 Display 特征并不是作为 Err 使用的必要条件,大家可以把这两个特征实现和相应使用去除,然后看看代码会否报错。既然如此,我们为何要为自定义类型实现这两个特征呢?原因有二:
- 错误得打印输出后,才能有实际用处,而打印输出就需要实现这两个特征
- 可以将自定义错误转换成
Box<dyn std::error:Error>特征对象,在后面的归一化不同错误类型部分,我们会详细介绍
更详尽的错误
上一个例子中定义的错误非常简单,我们无法从错误中得到更多的信息,现在再来定义一个具有错误码和信息的错误:
use std::fmt; struct AppError { code: usize, message: String, } // 根据错误码显示不同的错误信息 impl fmt::Display for AppError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let err_msg = match self.code { 404 => "Sorry, Can not find the Page!", _ => "Sorry, something is wrong! Please Try Again!", }; write!(f, "{}", err_msg) } } impl fmt::Debug for AppError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, "AppError {{ code: {}, message: {} }}", self.code, self.message ) } } fn produce_error() -> Result<(), AppError> { Err(AppError { code: 404, message: String::from("Page not found"), }) } fn main() { match produce_error() { Err(e) => eprintln!("{}", e), // 抱歉,未找到指定的页面! _ => println!("No error"), } eprintln!("{:?}", produce_error()); // Err(AppError { code: 404, message: Page not found }) eprintln!("{:#?}", produce_error()); // Err( // AppError { code: 404, message: Page not found } // ) }
在本例中,我们除了增加了错误码和消息外,还手动实现了 Debug 特征,原因在于,我们希望能自定义 Debug 的输出内容,而不是使用派生后系统提供的默认输出形式。
错误转换 From 特征
标准库、三方库、本地库,各有各的精彩,各也有各的错误。那么问题就来了,我们该如何将其它的错误类型转换成自定义的错误类型?总不能神鬼牛魔,同台共舞吧。。
好在 Rust 为我们提供了 std::convert::From 特征:
#![allow(unused)] fn main() { pub trait From<T>: Sized { fn from(_: T) -> Self; } }
事实上,该特征在之前的
?操作符章节中就有所介绍。大家都使用过
String::from函数吧?它可以通过&str来创建一个String,其实该函数就是From特征提供的
下面一起来看看如何为自定义类型实现 From 特征:
use std::fs::File; use std::io; #[derive(Debug)] struct AppError { kind: String, // 错误类型 message: String, // 错误信息 } // 为 AppError 实现 std::convert::From 特征,由于 From 包含在 std::prelude 中,因此可以直接简化引入。 // 实现 From<io::Error> 意味着我们可以将 io::Error 错误转换成自定义的 AppError 错误 impl From<io::Error> for AppError { fn from(error: io::Error) -> Self { AppError { kind: String::from("io"), message: error.to_string(), } } } fn main() -> Result<(), AppError> { let _file = File::open("nonexistent_file.txt")?; Ok(()) } // --------------- 上述代码运行后输出 --------------- Error: AppError { kind: "io", message: "No such file or directory (os error 2)" }
上面的代码中除了实现 From 外,还有一点特别重要,那就是 ? 可以将错误进行隐式的强制转换:File::open 返回的是 std::io::Error, 我们并没有进行任何显式的转换,它就能自动变成 AppError ,这就是 ? 的强大之处!
上面的例子只有一个标准库错误,再来看看多个不同的错误转换成 AppError 的实现:
use std::fs::File; use std::io::{self, Read}; use std::num; #[derive(Debug)] struct AppError { kind: String, message: String, } impl From<io::Error> for AppError { fn from(error: io::Error) -> Self { AppError { kind: String::from("io"), message: error.to_string(), } } } impl From<num::ParseIntError> for AppError { fn from(error: num::ParseIntError) -> Self { AppError { kind: String::from("parse"), message: error.to_string(), } } } fn main() -> Result<(), AppError> { let mut file = File::open("hello_world.txt")?; let mut content = String::new(); file.read_to_string(&mut content)?; let _number: usize; _number = content.parse()?; Ok(()) } // --------------- 上述代码运行后的可能输出 --------------- // 01. 若 hello_world.txt 文件不存在 Error: AppError { kind: "io", message: "No such file or directory (os error 2)" } // 02. 若用户没有相关的权限访问 hello_world.txt Error: AppError { kind: "io", message: "Permission denied (os error 13)" } // 03. 若 hello_world.txt 包含有非数字的内容,例如 Hello, world! Error: AppError { kind: "parse", message: "invalid digit found in string" }
归一化不同的错误类型
至此,关于 Rust 的错误处理大家已经了若指掌了,下面再来看看一些实战中的问题。
在实际项目中,我们往往会为不同的错误定义不同的类型,这样做非常好,但是如果你要在一个函数中返回不同的错误呢?例如:
use std::fs::read_to_string; fn main() -> Result<(), std::io::Error> { let html = render()?; println!("{}", html); Ok(()) } fn render() -> Result<String, std::io::Error> { let file = std::env::var("MARKDOWN")?; let source = read_to_string(file)?; Ok(source) }
上面的代码会报错,原因在于 render 函数中的两个 ? 返回的实际上是不同的错误:env::var() 返回的是 std::env::VarError,而 read_to_string 返回的是 std::io::Error。
为了满足 render 函数的签名,我们就需要将 env::VarError 和 io::Error 归一化为同一种错误类型。要实现这个目的有三种方式:
- 使用特征对象
Box<dyn Error> - 自定义错误类型
- 使用
thiserror
下面依次来看看相关的解决方式。
Box<dyn Error>
大家还记得我们之前提到的 std::error::Error 特征吧,当时有说:自定义类型实现 Debug + Display 特征的主要原因就是为了能转换成 Error 的特征对象,而特征对象恰恰是在同一个地方使用不同类型的关键:
use std::fs::read_to_string; use std::error::Error; fn main() -> Result<(), Box<dyn Error>> { let html = render()?; println!("{}", html); Ok(()) } fn render() -> Result<String, Box<dyn Error>> { let file = std::env::var("MARKDOWN")?; let source = read_to_string(file)?; Ok(source) }
这个方法很简单,在绝大多数场景中,性能也非常够用,但是有一个问题:Result 实际上不会限制错误的类型,也就是一个类型就算不实现 Error 特征,它依然可以在 Result<T, E> 中作为 E 来使用,此时这种特征对象的解决方案就无能为力了。
自定义错误类型
与特征对象相比,自定义错误类型麻烦归麻烦,但是它非常灵活,因此也不具有上面的类似限制:
use std::fs::read_to_string; fn main() -> Result<(), MyError> { let html = render()?; println!("{}", html); Ok(()) } fn render() -> Result<String, MyError> { let file = std::env::var("MARKDOWN")?; let source = read_to_string(file)?; Ok(source) } #[derive(Debug)] enum MyError { EnvironmentVariableNotFound, IOError(std::io::Error), } impl From<std::env::VarError> for MyError { fn from(_: std::env::VarError) -> Self { Self::EnvironmentVariableNotFound } } impl From<std::io::Error> for MyError { fn from(value: std::io::Error) -> Self { Self::IOError(value) } } impl std::error::Error for MyError {} impl std::fmt::Display for MyError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { MyError::EnvironmentVariableNotFound => write!(f, "Environment variable not found"), MyError::IOError(err) => write!(f, "IO Error: {}", err.to_string()), } } }
上面代码中有一行值得注意:impl std::error::Error for MyError {} ,只有为自定义错误类型实现 Error 特征后,才能转换成相应的特征对象。
不得不说,真是啰嗦啊。因此在能用特征对象的时候,建议大家还是使用特征对象,无论如何,代码可读性还是很重要的!
上面的第二种方式灵活归灵活,啰嗦也是真啰嗦,好在 Rust 的社区为我们提供了 thiserror 解决方案,下面一起来看看该如何简化 Rust 中的错误处理。
简化错误处理
对于开发者而言,错误处理是代码中打交道最多的部分之一,因此选择一把趁手的武器也很重要,它可以帮助我们节省大量的时间和精力,好钢应该用在代码逻辑而不是冗长的错误处理上。
thiserror
thiserror可以帮助我们简化上面的第二种解决方案:
use std::fs::read_to_string; fn main() -> Result<(), MyError> { let html = render()?; println!("{}", html); Ok(()) } fn render() -> Result<String, MyError> { let file = std::env::var("MARKDOWN")?; let source = read_to_string(file)?; Ok(source) } #[derive(thiserror::Error, Debug)] enum MyError { #[error("Environment variable not found")] EnvironmentVariableNotFound(#[from] std::env::VarError), #[error(transparent)] IOError(#[from] std::io::Error), }
如上所示,只要简单写写注释,就可以实现错误处理了,惊不惊喜?
error-chain
error-chain 也是简单好用的库,可惜不再维护了,但是我觉得它依然可以在合适的地方大放光彩,值得大家去了解下。
use std::fs::read_to_string; error_chain::error_chain! { foreign_links { EnvironmentVariableNotFound(::std::env::VarError); IOError(::std::io::Error); } } fn main() -> Result<()> { let html = render()?; println!("{}", html); Ok(()) } fn render() -> Result<String> { let file = std::env::var("MARKDOWN")?; let source = read_to_string(file)?; Ok(source) }
喏,简单吧?使用 error-chain 的宏你可以获得:Error 结构体,错误类型 ErrorKind 枚举 以及一个自定义的 Result 类型。
anyhow
anyhow 和 thiserror 是同一个作者开发的,这里是作者关于 anyhow 和 thiserror 的原话:
如果你想要设计自己的错误类型,同时给调用者提供具体的信息时,就使用
thiserror,例如当你在开发一个三方库代码时。如果你只想要简单,就使用anyhow,例如在自己的应用服务中。
use std::fs::read_to_string; use anyhow::Result; fn main() -> Result<()> { let html = render()?; println!("{}", html); Ok(()) } fn render() -> Result<String> { let file = std::env::var("MARKDOWN")?; let source = read_to_string(file)?; Ok(source) }
关于如何选用 thiserror 和 anyhow 只需要遵循一个原则即可:是否关注自定义错误消息,关注则使用 thiserror(常见业务代码),否则使用 anyhow(编写第三方库代码)。
总结
Rust 一个为人津津乐道的点就是强大、易用的错误处理,对于新手来说,这个机制可能会有些复杂,但是一旦体会到了其中的好处,你将跟我一样沉醉其中不能自拔。
unsafe 简介
圣人论迹不论心,论心世上无圣人,对于编程语言而言,亦是如此。
虽然在本章之前,我们学到的代码都是在编译期就得到了 Rust 的安全保障,但是在其内心深处也隐藏了一些阴暗面,在这些阴暗面里,内存安全就存在一些变数了:当不娴熟的开发者接触到这些阴暗面,就可能写出不安全的代码,因此我们称这种代码为 unsafe 代码块。
为何会有 unsafe
几乎每个语言都有 unsafe 关键字,但 Rust 语言使用 unsafe 的原因可能与其它编程语言还有所不同。
过强的编译器
说来尴尬,unsafe 的存在主要是因为 Rust 的静态检查太强了,但是强就算了,它还很保守,这就会导致当编译器在分析代码时,一些正确代码会因为编译器无法分析出它的所有正确性,结果将这段代码拒绝,导致编译错误。
这种保守的选择确实也没有错,毕竟安全就是要防微杜渐,但是对于使用者来说,就不是那么愉快的事了,特别是当配合 Rust 的所有权系统一起使用时,有个别问题是真的棘手和难以解决。
举个例子,在之前的自引用章节中,我们就提到了相关的编译检查是很难绕过的,如果想要绕过,最常用的方法之一就是使用 unsafe 和 Pin。
好在,当遇到这些情况时,我们可以使用 unsafe 来解决。此时,你需要替代编译器的部分职责对 unsafe 代码的正确性负责,例如在正常代码中不可能遇到的空指针解引用问题在 unsafe 中就可能会遇到,我们需要自己来处理好这些类似的问题。
特定任务的需要
至于 unsafe 存在的另一个原因就是:它必须要存在。原因是计算机底层的一些硬件就是不安全的,如果 Rust 只允许你做安全的操作,那一些任务就无法完成,换句话说,我们还怎么跟 C++ 干架?
Rust 的一个主要定位就是系统编程,众所周知,系统编程就是底层编程,往往需要直接跟操作系统打交道,甚至于去实现一个操作系统。而为了实现底层系统编程,unsafe 就是必不可少的。
在了解了为何会有 unsafe 后,我们再来看看,除了这些必要性,unsafe 还能给我们带来哪些超能力。
unsafe 的超能力
使用 unsafe 非常简单,只需要将对应的代码块标记下即可:
fn main() { let mut num = 5; let r1 = &num as *const i32; unsafe { println!("r1 is: {}", *r1); } }
上面代码中, r1 是一个裸指针(raw pointer),由于它具有破坏 Rust 内存安全的潜力,因此只能在 unsafe 代码块中使用,如果你去掉 unsafe {},编译器会立刻报错。
言归正传, unsafe 能赋予我们 5 种超能力,这些能力在安全的 Rust 代码中是无法获取的:
- 解引用裸指针,就如上例所示
- 调用一个
unsafe或外部的函数 - 访问或修改一个可变的静态变量
- 实现一个
unsafe特征 - 访问
union中的字段
在本章中,我们将着重讲解裸指针和 FFI 的使用。
unsafe 的安全保证
曾经在 reddit 上有一个讨论还挺热闹的,是关于 unsafe 的命名是否合适,总之公有公理,婆有婆理,但有一点是不可否认的:虽然名称自带不安全,但是 Rust 依然提供了强大的安全支撑。
首先,unsafe 并不能绕过 Rust 的借用检查,也不能关闭任何 Rust 的安全检查规则,例如当你在 unsafe 中使用引用时,该有的检查一样都不会少。
因此 unsafe 能给大家提供的也仅仅是之前的 5 种超能力,在使用这 5 种能力时,编译器才不会进行内存安全方面的检查,最典型的就是使用裸指针(引用和裸指针有很大的区别)。
谈虎色变?
在网上充斥着这样的言论:千万不要使用 unsafe,因为它不安全,甚至有些库会以没有 unsafe 代码作为噱头来吸引用户。事实上,大可不必,如果按照这个标准,Rust 的标准库也将不复存在!
Rust 中的 unsafe 其实没有那么可怕,虽然听上去很不安全,但是实际上 Rust 依然提供了很多机制来帮我们提升了安全性,因此不必像对待 Go 语言的 unsafe 那样去畏惧于使用 Rust 中的 unsafe 。
大致使用原则总结如下:没必要用时,就不要用,当有必要用时,就大胆用,但是尽量控制好边界,让 unsafe 的范围尽可能小。
控制 unsafe 的使用边界
unsafe 不安全,但是该用的时候就要用,在一些时候,它能帮助我们大幅降低代码实现的成本。
而作为使用者,你的水平决定了 unsafe 到底有多不安全,因此你需要在 unsafe 中小心谨慎地去访问内存。
即使做到小心谨慎,依然会有出错的可能性,但是 unsafe 语句块决定了:就算内存访问出错了,你也能立刻意识到,错误是在 unsafe 代码块中,而不花大量时间像无头苍蝇一样去寻找问题所在。
正因为此,写代码时要尽量控制好 unsafe 的边界大小,越小的 unsafe 越会让我们在未来感谢自己当初的选择。
除了控制边界大小,另一个很常用的方式就是在 unsafe 代码块外包裹一层 safe 的 API,例如一个函数声明为 safe 的,然后在其内部有一块儿是 unsafe 代码。
忍不住抱怨一句,内存安全方面的 bug ,是真心难查!
五种兵器
古龙有一部小说,名为"七种兵器",其中每一种都精妙绝伦,令人闻风丧胆,而 unsafe 也有五种兵器,它们可以让你拥有其它代码无法实现的能力,同时它们也像七种兵器一样令人闻风丧胆,下面一起来看看庐山真面目。
解引用裸指针
裸指针(raw pointer,又称原生指针) 在功能上跟引用类似,同时它也需要显式地注明可变性。但是又和引用有所不同,裸指针长这样: *const T 和 *mut T,它们分别代表了不可变和可变。
大家在之前学过 * 操作符,知道它可以用于解引用,但是在裸指针 *const T 中,这里的 * 只是类型名称的一部分,并没有解引用的含义。
至此,我们已经学过三种类似指针的概念:引用、智能指针和裸指针。与前两者不同,裸指针:
- 可以绕过 Rust 的借用规则,可以同时拥有一个数据的可变、不可变指针,甚至还能拥有多个可变的指针
- 并不能保证指向合法的内存
- 可以是
null - 没有实现任何自动的回收 (drop)
总之,裸指针跟 C 指针是非常像的,使用它需要以牺牲安全性为前提,但我们获得了更好的性能,也可以跟其它语言或硬件打交道。
基于引用创建裸指针
下面的代码基于值的引用同时创建了可变和不可变的裸指针:
#![allow(unused)] fn main() { let mut num = 5; let r1 = &num as *const i32; let r2 = &mut num as *mut i32; }
as 可以用于强制类型转换,在之前章节中有讲解。在这里,我们将引用 &num / &mut num 强转为相应的裸指针 *const i32 / *mut i32。
细心的同学可能会发现,在这段代码中并没有 unsafe 的身影,原因在于:创建裸指针是安全的行为,而解引用裸指针才是不安全的行为 :
fn main() { let mut num = 5; let r1 = &num as *const i32; unsafe { println!("r1 is: {}", *r1); } }
基于内存地址创建裸指针
在上面例子中,我们基于现有的引用来创建裸指针,这种行为是很安全的。但是接下来的方式就不安全了:
#![allow(unused)] fn main() { let address = 0x012345usize; let r = address as *const i32; }
这里基于一个内存地址来创建裸指针,可以想像,这种行为是相当危险的。试图使用任意的内存地址往往是一种未定义的行为(undefined behavior),因为该内存地址有可能存在值,也有可能没有,就算有值,也大概率不是你需要的值。
同时编译器也有可能会优化这段代码,会造成没有任何内存访问发生,甚至程序还可能发生段错误(segmentation fault)。总之,你几乎没有好的理由像上面这样实现代码,虽然它是可行的。
如果真的要使用内存地址,也是类似下面的用法,先取地址,再使用,而不是凭空捏造一个地址:
use std::{slice::from_raw_parts, str::from_utf8_unchecked}; // 获取字符串的内存地址和长度 fn get_memory_location() -> (usize, usize) { let string = "Hello World!"; let pointer = string.as_ptr() as usize; let length = string.len(); (pointer, length) } // 在指定的内存地址读取字符串 fn get_str_at_location(pointer: usize, length: usize) -> &'static str { unsafe { from_utf8_unchecked(from_raw_parts(pointer as *const u8, length)) } } fn main() { let (pointer, length) = get_memory_location(); let message = get_str_at_location(pointer, length); println!( "The {} bytes at 0x{:X} stored: {}", length, pointer, message ); // 如果大家想知道为何处理裸指针需要 `unsafe`,可以试着反注释以下代码 // let message = get_str_at_location(1000, 10); }
以上代码同时还演示了访问非法内存地址会发生什么,大家可以试着去反注释这段代码试试。
使用 * 解引用
#![allow(unused)] fn main() { let a = 1; let b: *const i32 = &a as *const i32; let c: *const i32 = &a; unsafe { println!("{}", *c); } }
使用 * 可以对裸指针进行解引用,由于该指针的内存安全性并没有任何保证,因此我们需要使用 unsafe 来包裹解引用的逻辑(切记,unsafe 语句块的范围一定要尽可能的小,具体原因在上一章节有讲)。
以上代码另一个值得注意的点就是:除了使用 as 来显式的转换,我们还使用了隐式的转换方式 let c: *const i32 = &a;。在实际使用中,我们建议使用 as 来转换,因为这种显式的方式更有助于提醒用户:你在使用的指针是裸指针,需要小心。
基于智能指针创建裸指针
还有一种创建裸指针的方式,那就是基于智能指针来创建:
#![allow(unused)] fn main() { let a: Box<i32> = Box::new(10); // 需要先解引用a let b: *const i32 = &*a; // 使用 into_raw 来创建 let c: *const i32 = Box::into_raw(a); }
小结
像之前代码演示的那样,使用裸指针可以让我们创建两个可变指针都指向同一个数据,如果使用安全的 Rust,你是无法做到这一点的,违背了借用规则,编译器会对我们进行无情的阻止。因此裸指针可以绕过借用规则,但是由此带来的数据竞争问题,就需要大家自己来处理了,总之,需要小心!
既然这么危险,为何还要使用裸指针?除了之前提到的性能等原因,还有一个重要用途就是跟 C 语言的代码进行交互( FFI ),在讲解 FFI 之前,先来看看如何调用 unsafe 函数或方法。
调用 unsafe 函数或方法
unsafe 函数从外表上来看跟普通函数并无区别,唯一的区别就是它需要使用 unsafe fn 来进行定义。这种定义方式是为了告诉调用者:当调用此函数时,你需要注意它的相关需求,因为 Rust 无法担保调用者在使用该函数时能满足它所需的一切需求。
强制调用者加上 unsafe 语句块,就可以让他清晰的认识到,正在调用一个不安全的函数,需要小心看看文档,看看函数有哪些特别的要求需要被满足。
unsafe fn dangerous() {} fn main() { dangerous(); }
如果试图像上面这样调用,编译器就会报错:
error[E0133]: call to unsafe function is unsafe and requires unsafe function or block
--> src/main.rs:3:5
|
3 | dangerous();
| ^^^^^^^^^^^ call to unsafe function
按照报错提示,加上 unsafe 语句块后,就能顺利执行了:
#![allow(unused)] fn main() { unsafe { dangerous(); } }
道理很简单,但一定要牢记在心:使用 unsafe 声明的函数时,一定要看看相关的文档,确定自己没有遗漏什么。
还有,unsafe 无需俄罗斯套娃,在 unsafe 函数体中使用 unsafe 语句块是多余的行为。
用安全抽象包裹 unsafe 代码
一个函数包含了 unsafe 代码不代表我们需要将整个函数都定义为 unsafe fn。事实上,在标准库中有大量的安全函数,它们内部都包含了 unsafe 代码块,下面我们一起来看看一个很好用的标准库函数:split_at_mut。
大家可以想象一下这个场景:需要将一个数组分成两个切片,且每一个切片都要求是可变的。类似需求在安全 Rust 中是很难实现的,因为要对同一个数组做两个可变借用:
fn split_at_mut(slice: &mut [i32], mid: usize) -> (&mut [i32], &mut [i32]) { let len = slice.len(); assert!(mid <= len); (&mut slice[..mid], &mut slice[mid..]) } fn main() { let mut v = vec![1, 2, 3, 4, 5, 6]; let r = &mut v[..]; let (a, b) = split_at_mut(r, 3); assert_eq!(a, &mut [1, 2, 3]); assert_eq!(b, &mut [4, 5, 6]); }
上面代码一眼看过去就知道会报错,因为我们试图在自定义的 split_at_mut 函数中,可变借用 slice 两次:
error[E0499]: cannot borrow `*slice` as mutable more than once at a time
--> src/main.rs:6:30
|
1 | fn split_at_mut(slice: &mut [i32], mid: usize) -> (&mut [i32], &mut [i32]) {
| - let's call the lifetime of this reference `'1`
...
6 | (&mut slice[..mid], &mut slice[mid..])
| -------------------------^^^^^--------
| | | |
| | | second mutable borrow occurs here
| | first mutable borrow occurs here
| returning this value requires that `*slice` is borrowed for `'1`
对于 Rust 的借用检查器来说,它无法理解我们是分别借用了同一个切片的两个不同部分,但事实上,这种行为是没任何问题的,毕竟两个借用没有任何重叠之处。总之,不太聪明的 Rust 编译器阻碍了我们用这种简单且安全的方式去实现,那只能剑走偏锋,试试 unsafe 了。
use std::slice; fn split_at_mut(slice: &mut [i32], mid: usize) -> (&mut [i32], &mut [i32]) { let len = slice.len(); let ptr = slice.as_mut_ptr(); assert!(mid <= len); unsafe { ( slice::from_raw_parts_mut(ptr, mid), slice::from_raw_parts_mut(ptr.add(mid), len - mid), ) } } fn main() { let mut v = vec![1, 2, 3, 4, 5, 6]; let r = &mut v[..]; let (a, b) = split_at_mut(r, 3); assert_eq!(a, &mut [1, 2, 3]); assert_eq!(b, &mut [4, 5, 6]); }
相比安全实现,这段代码就显得没那么好理解了,甚至于我们还需要像 C 语言那样,通过指针地址的偏移去控制数组的分割。
as_mut_ptr会返回指向slice首地址的裸指针*mut i32slice::from_raw_parts_mut函数通过指针和长度来创建一个新的切片,简单来说,该切片的初始地址是ptr,长度为midptr.add(mid)可以获取第二个切片的初始地址,由于切片中的元素是i32类型,每个元素都占用了 4 个字节的内存大小,因此我们不能简单的用ptr + mid来作为初始地址,而应该使用ptr + 4 * mid,但是这种使用方式并不安全,因此.add方法是最佳选择
由于 slice::from_raw_parts_mut 使用裸指针作为参数,因此它是一个 unsafe fn,我们在使用它时,就必须用 unsafe 语句块进行包裹,类似的,.add 方法也是如此(还是那句话,不要将无关的代码包含在 unsafe 语句块中)。
部分同学可能会有疑问,那这段代码我们怎么保证 unsafe 中使用的裸指针 ptr 和 ptr.add(mid) 是合法的呢?秘诀就在于 assert!(mid <= len); ,通过这个断言,我们保证了裸指针一定指向了 slice 切片中的某个元素,而不是一个莫名其妙的内存地址。
再回到我们的主题:虽然 split_at_mut 使用了 unsafe,但我们无需将其声明为 unsafe fn,这种情况下就是使用安全的抽象包裹 unsafe 代码,这里的 unsafe 使用是非常安全的,因为我们从合法数据中创建了的合法指针。
与之对比,下面的代码就非常危险了:
#![allow(unused)] fn main() { use std::slice; let address = 0x01234usize; let r = address as *mut i32; let slice: &[i32] = unsafe { slice::from_raw_parts_mut(r, 10000) }; println!("{:?}",slice); }
这段代码从一个任意的内存地址,创建了一个 10000 长度的 i32 切片,我们无法保证切片中的元素都是合法的 i32 值,这种访问就是一种未定义行为(UB = undefined behavior)。
zsh: segmentation fault
不出所料,运行后看到了一个段错误。
FFI
FFI(Foreign Function Interface)可以用来与其它语言进行交互,但是并不是所有语言都这么称呼,例如 Java 称之为 JNI(Java Native Interface)。
FFI 之所以存在是由于现实中很多代码库都是由不同语言编写的,如果我们需要使用某个库,但是它是由其它语言编写的,那么往往只有两个选择:
- 对该库进行重写或者移植
- 使用
FFI
前者相当不错,但是在很多时候,并没有那么多时间去重写,因此 FFI 就成了最佳选择。回到 Rust 语言上,由于这门语言依然很年轻,一些生态是缺失的,我们在写一些不是那么大众的项目时,可能会同时遇到没有相应的 Rust 库可用的尴尬境况,此时通过 FFI 去调用 C 语言的库就成了相当棒的选择。
还有在将 C/C++ 的代码重构为 Rust 时,先将相关代码引入到 Rust 项目中,然后逐步重构,也是不错的(为什么用不错来形容?因为重构一个有一定规模的 C/C++ 项目远没有想象中美好,因此最好的选择还是对于新项目使用 Rust 实现,老项目。。就让它先运行着吧)。
当然,除了 FFI 还有一个办法可以解决跨语言调用的问题,那就是将其作为一个独立的服务,然后使用网络调用的方式去访问,HTTP,gRPC 都可以。
言归正传,之前我们提到 unsafe 的另一个重要目的就是对 FFI 提供支持,它的全称是 Foreign Function Interface,顾名思义,通过 FFI , 我们的 Rust 代码可以跟其它语言的外部代码进行交互。
下面的例子演示了如何调用 C 标准库中的 abs 函数:
extern "C" { fn abs(input: i32) -> i32; } fn main() { unsafe { println!("Absolute value of -3 according to C: {}", abs(-3)); } }
C 语言的代码定义在了 extern 代码块中, 而 extern 必须使用 unsafe 才能进行进行调用,原因在于其它语言的代码并不会强制执行 Rust 的规则,因此 Rust 无法对这些代码进行检查,最终还是要靠开发者自己来保证代码的正确性和程序的安全性。
ABI
在 extern "C" 代码块中,我们列出了想要调用的外部函数的签名。其中 "C" 定义了外部函数所使用的应用二进制接口ABI (Application Binary Interface):ABI 定义了如何在汇编层面来调用该函数。在所有 ABI 中,C 语言的是最常见的。
在其它语言中调用 Rust 函数
在 Rust 中调用其它语言的函数是让 Rust 利用其他语言的生态,那反过来可以吗?其他语言可以利用 Rust 的生态不?答案是肯定的。
我们可以使用 extern 来创建一个接口,其它语言可以通过该接口来调用相关的 Rust 函数。但是此处的语法与之前有所不同,之前用的是语句块,而这里是在函数定义时加上 extern 关键字,当然,别忘了指定相应的 ABI:
#![allow(unused)] fn main() { #[no_mangle] pub extern "C" fn call_from_c() { println!("Just called a Rust function from C!"); } }
上面的代码可以让 call_from_c 函数被 C 语言的代码调用,当然,前提是将其编译成一个共享库,然后链接到 C 语言中。
这里还有一个比较奇怪的注解 #[no_mangle],它用于告诉 Rust 编译器:不要乱改函数的名称。 Mangling 的定义是:当 Rust 因为编译需要去修改函数的名称,例如为了让名称包含更多的信息,这样其它的编译部分就能从该名称获取相应的信息,这种修改会导致函数名变得相当不可读。
因此,为了让 Rust 函数能顺利被其它语言调用,我们必须要禁止掉该功能。
访问或修改一个可变的静态变量
这部分我们在之前的全局变量章节中有过详细介绍,这里就不再赘述,大家可以前往此章节阅读。
实现 unsafe 特征
说实话,unsafe 的特征确实不多见,如果大家还记得的话,我们在之前的 Send 和 Sync 章节中实现过 unsafe 特征 Send。
之所以会有 unsafe 的特征,是因为该特征至少有一个方法包含有编译器无法验证的内容。unsafe 特征的声明很简单:
unsafe trait Foo { // 方法列表 } unsafe impl Foo for i32 { // 实现相应的方法 } fn main() {}
通过 unsafe impl 的使用,我们告诉编译器:相应的正确性由我们自己来保证。
再回到刚提到的 Send 特征,若我们的类型中的所有字段都实现了 Send 特征,那该类型也会自动实现 Send。但是如果我们想要为某个类型手动实现 Send ,例如为裸指针,那么就必须使用 unsafe,相关的代码在之前的链接中也有,大家可以移步查看。
总之,Send 特征标记为 unsafe 是因为 Rust 无法验证我们的类型是否能在线程间安全的传递,因此就需要通过 unsafe 来告诉编译器,它无需操心,剩下的交给我们自己来处理。
访问 union 中的字段
截止目前,我们还没有介绍过 union ,原因很简单,它主要用于跟 C 代码进行交互。
访问 union 的字段是不安全的,因为 Rust 无法保证当前存储在 union 实例中的数据类型。
#![allow(unused)] fn main() { #[repr(C)] union MyUnion { f1: u32, f2: f32, } }
上从可以看出,union 的使用方式跟结构体确实很相似,但是前者的所有字段都共享同一个存储空间,意味着往 union 的某个字段写入值,会导致其它字段的值会被覆盖。
关于 union 的更多信息,可以在这里查看。
一些实用工具(库)
由于 unsafe 和 FFI 在 Rust 的使用场景中是相当常见的(例如相对于 Go 的 unsafe 来说),因此社区已经开发出了相当一部分实用的工具,可以改善相应的开发体验。
rust-bindgen 和 cbindgen
对于 FFI 调用来说,保证接口的正确性是非常重要的,这两个库可以帮我们自动生成相应的接口,其中 rust-bindgen 用于在 Rust 中访问 C 代码,而 cbindgen则反之。
下面以 rust-bindgen 为例,来看看如何自动生成调用 C 的代码,首先下面是 C 代码:
typedef struct Doggo {
int many;
char wow;
} Doggo;
void eleven_out_of_ten_majestic_af(Doggo* pupper);
下面是自动生成的可以调用上面代码的 Rust 代码:
#![allow(unused)] fn main() { /* automatically generated by rust-bindgen 0.99.9 */ #[repr(C)] pub struct Doggo { pub many: ::std::os::raw::c_int, pub wow: ::std::os::raw::c_char, } extern "C" { pub fn eleven_out_of_ten_majestic_af(pupper: *mut Doggo); } }
cxx
如果需要跟 C++ 代码交互,非常推荐使用 cxx,它提供了双向的调用,最大的优点就是安全:是的,你无需通过 unsafe 来使用它!
Miri
miri 可以生成 Rust 的中间层表示 MIR,对于编译器来说,我们的 Rust 代码首先会被编译为 MIR ,然后再提交给 LLVM 进行处理。
可以通过 rustup component add miri 来安装它,并通过 cargo miri 来使用,同时还可以使用 cargo miri test 来运行测试代码。
miri 可以帮助我们检查常见的未定义行为(UB = Undefined Behavior),以下列出了一部分:
- 内存越界检查和内存释放后再使用(use-after-free)
- 使用未初始化的数据
- 数据竞争
- 内存对齐问题
但是需要注意的是,它只能帮助识别被执行代码路径的风险,那些未被执行到的代码是没办法被识别的。
Clippy
官方的 clippy 检查器提供了有限的 unsafe 支持,虽然不多,但是至少有一定帮助。例如 missing_safety_docs 检查可以帮助我们检查哪些 unsafe 函数遗漏了文档。
需要注意的是: Rust 编译器并不会默认开启所有检查,大家可以调用 rustc -W help 来看看最新的信息。
Prusti
prusti 需要大家自己来构建一个证明,然后通过它证明代码中的不变量是正确被使用的,当你在安全代码中使用不安全的不变量时,就会非常有用。具体的使用文档见这里。
模糊测试(fuzz testing)
在 Rust Fuzz Book 中列出了一些 Rust 可以使用的模糊测试方法。
同时,我们还可以使用 rutenspitz 这个过程宏来测试有状态的代码,例如数据结构。
总结
至此,unsafe 的五种兵器已介绍完毕,大家是否意犹未尽?我想说的是,就算意犹未尽,也没有其它兵器了。
就像上一章中所提到的,unsafe 只应该用于这五种场景,其它场景,你应该坚决的使用安全的代码,否则就会像 actix-web 的前作者一样,被很多人议论,甚至被喷。。。
总之,能不使用 unsafe 一定不要使用,就算使用也要控制好边界,让范围尽可能的小,就像本章的例子一样,只有真的需要 unsafe 的代码,才应该包含其中, 而不是将无关代码也纳入进来。
进一步学习
内联汇编
Macro 宏编程
在编程世界可以说是谈“宏”色变,原因在于 C 语言中的宏是非常危险的东东,但并不是所有语言都像 C 这样,例如对于古老的语言 Lisp 来说,宏就是就是一个非常强大的好帮手。
那话说回来,在 Rust 中宏到底是好是坏呢?本章将带你揭开它的神秘面纱。
事实上,我们虽然没有见过宏,但是已经多次用过它,例如在全书的第一个例子中就用到了:println!("你好,世界"),这里 println! 就是一个最常用的宏,可以看到它和函数最大的区别是:它在调用时多了一个 !,除此之外还有 vec! 、assert_eq! 都是相当常用的,可以说宏在 Rust 中无处不在。
细心的读者可能会注意到 println! 后面跟着的是 (),而 vec! 后面跟着的是 [],这是因为宏的参数可以使用 ()、[] 以及 {}:
fn main() { println!("aaaa"); println!["aaaa"]; println!{"aaaa"} }
虽然三种使用形式皆可,但是 Rust 内置的宏都有自己约定俗成的使用方式,例如 vec![...]、assert_eq!(...) 等。
在 Rust 中宏分为两大类:声明式宏( declarative macros ) macro_rules! 和三种过程宏( procedural macros ):
#[derive],在之前多次见到的派生宏,可以为目标结构体或枚举派生指定的代码,例如Debug特征- 类属性宏(Attribute-like macro),用于为目标添加自定义的属性
- 类函数宏(Function-like macro),看上去就像是函数调用
如果感觉难以理解,也不必担心,接下来我们将逐个看看它们的庐山真面目,在此之前,先来看下为何需要宏,特别是 Rust 的函数明明已经很强大了。
宏和函数的区别
宏和函数的区别并不少,而且对于宏擅长的领域,函数其实是有些无能为力的。
元编程
从根本上来说,宏是通过一种代码来生成另一种代码,如果大家熟悉元编程,就会发现两者的共同点。
在附录 D中讲到的 derive 属性,就会自动为结构体派生出相应特征所需的代码,例如 #[derive(Debug)],还有熟悉的 println! 和 vec!,所有的这些宏都会展开成相应的代码,且很可能是长得多的代码。
总之,元编程可以帮我们减少所需编写的代码,也可以一定程度上减少维护的成本,虽然函数复用也有类似的作用,但是宏依然拥有自己独特的优势。
可变参数
Rust 的函数签名是固定的:定义了两个参数,就必须传入两个参数,多一个少一个都不行,对于从 JS/TS 过来的同学,这一点其实是有些恼人的。
而宏就可以拥有可变数量的参数,例如可以调用一个参数的 println!("hello"),也可以调用两个参数的 println!("hello {}", name)。
宏展开
由于宏会被展开成其它代码,且这个展开过程是发生在编译器对代码进行解释之前。因此,宏可以为指定的类型实现某个特征:先将宏展开成实现特征的代码后,再被编译。
而函数就做不到这一点,因为它直到运行时才能被调用,而特征需要在编译期被实现。
宏的缺点
相对函数来说,由于宏是基于代码再展开成代码,因此实现相比函数来说会更加复杂,再加上宏的语法更为复杂,最终导致定义宏的代码相当地难读,也难以理解和维护。
声明式宏 macro_rules!
在 Rust 中使用最广的就是声明式宏,它们也有一些其它的称呼,例如示例宏( macros by example )、macro_rules! 或干脆直接称呼为宏。
声明式宏允许我们写出类似 match 的代码。match 表达式是一个控制结构,其接收一个表达式,然后将表达式的结果与多个模式进行匹配,一旦匹配了某个模式,则该模式相关联的代码将被执行:
#![allow(unused)] fn main() { match target { 模式1 => 表达式1, 模式2 => { 语句1; 语句2; 表达式2 }, _ => 表达式3 } }
而宏也是将一个值跟对应的模式进行匹配,且该模式会与特定的代码相关联。但是与 match 不同的是,宏里的值是一段 Rust 源代码(字面量),模式用于跟这段源代码的结构相比较,一旦匹配,传入宏的那段源代码将被模式关联的代码所替换,最终实现宏展开。值得注意的是,所有的这些都是在编译期发生,并没有运行期的性能损耗。
简化版的 vec!
在动态数组 Vector 章节中,我们学习了使用 vec! 来便捷的初始化一个动态数组:
#![allow(unused)] fn main() { let v: Vec<u32> = vec![1, 2, 3]; }
最重要的是,通过 vec! 创建的动态数组支持任何元素类型,也并没有限制数组的长度,如果使用函数,我们是无法做到这一点的。
好在我们有 macro_rules!,来看看该如何使用它来实现 vec!,以下是一个简化实现:
#![allow(unused)] fn main() { #[macro_export] macro_rules! vec { ( $( $x:expr ),* ) => { { let mut temp_vec = Vec::new(); $( temp_vec.push($x); )* temp_vec } }; } }
简化实现版本?这也太难了吧!!只能说,欢迎来到宏的世界,在这里你能见到优雅 Rust 的另一面:) 标准库中的 vec! 还包含了预分配内存空间的代码,如果引入进来,那大家将更难以接受。
#[macro_export] 注释将宏进行了导出,这样其它的包就可以将该宏引入到当前作用域中,然后才能使用。可能有同学会提问:我们在使用标准库 vec! 时也没有引入宏啊,那是因为 Rust 已经通过 std::prelude 的方式为我们自动引入了。
紧接着,就使用 macro_rules! 进行了宏定义,需要注意的是宏的名称是 vec,而不是 vec!,后者的感叹号只在调用时才需要。
vec 的定义结构跟 match 表达式很像,但这里我们只有一个分支,其中包含一个模式 ( $( $x:expr ),* ),跟模式相关联的代码就在 => 之后。一旦模式成功匹配,那这段相关联的代码就会替换传入的源代码。
由于 vec 宏只有一个模式,因此它只能匹配一种源代码,其它类型的都将导致报错,而更复杂的宏往往会拥有更多的分支。
虽然宏和 match 都称之为模式,但是前者跟后者的模式规则是不同的。如果大家想要更深入的了解宏的模式,可以查看这里。
模式解析
而现在,我们先来简单讲解下 ( $( $x:expr ),* ) 的含义。
首先,我们使用圆括号 () 将整个宏模式包裹其中。紧随其后的是 $(),跟括号中模式相匹配的值(传入的 Rust 源代码)会被捕获,然后用于代码替换。在这里,模式 $x:expr 会匹配任何 Rust 表达式并给予该模式一个名称:$x。
$() 之后的逗号说明在 $() 所匹配的代码的后面会有一个可选的逗号分隔符,紧随逗号之后的 * 说明 * 之前的模式会被匹配零次或任意多次(类似正则表达式)。
当我们使用 vec![1, 2, 3] 来调用该宏时,$x 模式将被匹配三次,分别是 1、2、3。为了帮助大家巩固,我们再来一起过一下:
$()中包含的是模式$x:expr,该模式中的expr表示会匹配任何 Rust 表达式,并给予该模式一个名称$x- 因此
$x模式可以跟整数1进行匹配,也可以跟字符串 "hello" 进行匹配:vec!["hello", "world"] $()之后的逗号,意味着1和2之间可以使用逗号进行分割,也意味着3既可以没有逗号,也可以有逗号:vec![1, 2, 3,]*说明之前的模式可以出现零次也可以任意次,这里出现了三次
接下来,我们再来看看与模式相关联、在 => 之后的代码:
#![allow(unused)] fn main() { { { let mut temp_vec = Vec::new(); $( temp_vec.push($x); )* temp_vec } }; }
这里就比较好理解了,$() 中的 temp_vec.push() 将根据模式匹配的次数生成对应的代码,当调用 vec![1, 2, 3] 时,下面这段生成的代码将替代传入的源代码,也就是替代 vec![1, 2, 3] :
#![allow(unused)] fn main() { { let mut temp_vec = Vec::new(); temp_vec.push(1); temp_vec.push(2); temp_vec.push(3); temp_vec } }
如果是 let v = vec![1, 2, 3],那生成的代码最后返回的值 temp_vec 将被赋予给变量 v,等同于 :
#![allow(unused)] fn main() { let v = { let mut temp_vec = Vec::new(); temp_vec.push(1); temp_vec.push(2); temp_vec.push(3); temp_vec } }
至此,我们定义了一个宏,它可以接受任意类型和数量的参数,并且理解了其语法的含义。
未来将被替代的 macro_rules
对于 macro_rules 来说,它是存在一些问题的,因此,Rust 计划在未来使用新的声明式宏来替换它:工作方式类似,但是解决了目前存在的一些问题,在那之后,macro_rules 将变为 deprecated 状态。
由于绝大多数 Rust 开发者都是宏的用户而不是编写者,因此在这里我们不会对 macro_rules 进行更深入的学习,如果大家感兴趣,可以看看这本书 “The Little Book of Rust Macros”。
用过程宏为属性标记生成代码
第二种常用的宏就是过程宏 ( procedural macros ),从形式上来看,过程宏跟函数较为相像,但过程宏是使用源代码作为输入参数,基于代码进行一系列操作后,再输出一段全新的代码。注意,过程宏中的 derive 宏输出的代码并不会替换之前的代码,这一点与声明宏有很大的不同!
至于前文提到的过程宏的三种类型(自定义 derive、属性宏、函数宏),它们的工作方式都是类似的。
当创建过程宏时,它的定义必须要放入一个独立的包中,且包的类型也是特殊的,这么做的原因相当复杂,大家只要知道这种限制在未来可能会有所改变即可。
事实上,根据这个说法,过程宏放入独立包的原因在于它必须先被编译后才能使用,如果过程宏和使用它的代码在一个包,就必须先单独对过程宏的代码进行编译,然后再对我们的代码进行编译,但悲剧的是 Rust 的编译单元是包,因此你无法做到这一点。
假设我们要创建一个 derive 类型的过程宏:
#![allow(unused)] fn main() { use proc_macro; #[proc_macro_derive(HelloMacro)] pub fn some_name(input: TokenStream) -> TokenStream { } }
用于定义过程宏的函数 some_name 使用 TokenStream 作为输入参数,并且返回的也是同一个类型。TokenStream 是在 proc_macro 包中定义的,顾名思义,它代表了一个 Token 序列。
在理解了过程宏的基本定义后,我们再来看看该如何创建三种类型的过程宏,首先,从大家最熟悉的 derive 开始。
自定义 derive 过程宏
假设我们有一个特征 HelloMacro,现在有两种方式让用户使用它:
- 为每个类型手动实现该特征,就像之前特征章节所做的
- 使用过程宏来统一实现该特征,这样用户只需要对类型进行标记即可:
#[derive(HelloMacro)]
以上两种方式并没有孰优孰劣,主要在于不同的类型是否可以使用同样的默认特征实现,如果可以,那过程宏的方式可以帮我们减少很多代码实现:
use hello_macro::HelloMacro; use hello_macro_derive::HelloMacro; #[derive(HelloMacro)] struct Sunfei; #[derive(HelloMacro)] struct Sunface; fn main() { Sunfei::hello_macro(); Sunface::hello_macro(); }
简单吗?简单!不过为了实现这段代码展示的功能,我们还需要创建相应的过程宏才行。 首先,创建一个新的工程用于演示:
$ cargo new hello_macro
$ cd hello_macro/
$ touch src/lib.rs
此时,src 目录下包含两个文件 lib.rs 和 main.rs,前者是 lib 包根,后者是二进制包根,如果大家对包根不熟悉,可以看看这里。
接下来,先在 src/lib.rs 中定义过程宏所需的 HelloMacro 特征和其关联函数:
#![allow(unused)] fn main() { pub trait HelloMacro { fn hello_macro(); } }
然后在 src/main.rs 中编写主体代码,首先映入大家脑海的可能会是如下实现:
use hello_macro::HelloMacro; struct Sunfei; impl HelloMacro for Sunfei { fn hello_macro() { println!("Hello, Macro! My name is Sunfei!"); } } struct Sunface; impl HelloMacro for Sunface { fn hello_macro() { println!("Hello, Macro! My name is Sunface!"); } } fn main() { Sunfei::hello_macro(); }
但是这种方式有个问题,如果想要实现不同的招呼内容,就需要为每一个类型都实现一次相应的特征,Rust 不支持反射,因此我们无法在运行时获得类型名。
使用宏,就不存在这个问题:
use hello_macro::HelloMacro; use hello_macro_derive::HelloMacro; #[derive(HelloMacro)] struct Sunfei; #[derive(HelloMacro)] struct Sunface; fn main() { Sunfei::hello_macro(); Sunface::hello_macro(); }
简单明了的代码总是令人愉快,为了让代码运行起来,还需要定义下过程宏。就如前文提到的,目前只能在单独的包中定义过程宏,尽管未来这种限制会被取消,但是现在我们还得遵循这个规则。
宏所在的包名自然也有要求,必须以 derive 为后缀,对于 hello_macro 宏而言,包名就应该是 hello_macro_derive。在之前创建的 hello_macro 项目根目录下,运行如下命令,创建一个单独的 lib 包:
#![allow(unused)] fn main() { cargo new hello_macro_derive --lib }
至此, hello_macro 项目的目录结构如下:
hello_macro
├── Cargo.toml
├── src
│ ├── main.rs
│ └── lib.rs
└── hello_macro_derive
├── Cargo.toml
├── src
└── lib.rs
由于过程宏所在的包跟我们的项目紧密相连,因此将它放在项目之中。现在,问题又来了,该如何在项目的 src/main.rs 中引用 hello_macro_derive 包的内容?
方法有两种,第一种是将 hello_macro_derive 发布到 crates.io 或 GitHub 中,就像我们引用的其它依赖一样;另一种就是使用相对路径引入的本地化方式,修改 hello_macro/Cargo.toml 文件添加以下内容:
[dependencies]
hello_macro_derive = { path = "../hello_macro/hello_macro_derive" }
# 也可以使用下面的相对路径
# hello_macro_derive = { path = "./hello_macro_derive" }
此时,hello_macro 项目就可以成功的引用到 hello_macro_derive 本地包了,对于项目依赖引入的详细介绍,可以参见 Cargo 章节。
另外,学习过程更好的办法是通过展开宏来阅读和调试自己写的宏,这里需要用到一个 cargo-expand 的工具,可以通过下面的命令安装
cargo install cargo-expand
接下来,就到了重头戏环节,一起来看看该如何定义过程宏。
定义过程宏
首先,在 hello_macro_derive/Cargo.toml 文件中添加以下内容:
[lib]
proc-macro = true
[dependencies]
syn = "1.0"
quote = "1.0"
其中 syn 和 quote 依赖包都是定义过程宏所必需的,同时,还需要在 [lib] 中将过程宏的开关开启 : proc-macro = true。
其次,在 hello_macro_derive/src/lib.rs 中添加如下代码:
#![allow(unused)] fn main() { extern crate proc_macro; use proc_macro::TokenStream; use quote::quote; use syn; use syn::DeriveInput; #[proc_macro_derive(HelloMacro)] pub fn hello_macro_derive(input: TokenStream) -> TokenStream { // 基于 input 构建 AST 语法树 let ast:DeriveInput = syn::parse(input).unwrap(); // 构建特征实现代码 impl_hello_macro(&ast) } }
这个函数的签名我们在之前已经介绍过,总之,这种形式的过程宏定义是相当通用的,下面来分析下这段代码。
首先有一点,对于绝大多数过程宏而言,这段代码往往只在 impl_hello_macro(&ast) 中的实现有所区别,对于其它部分基本都是一致的,例如包的引入、宏函数的签名、语法树构建等。
proc_macro 包是 Rust 自带的,因此无需在 Cargo.toml 中引入依赖,它包含了相关的编译器 API,可以用于读取和操作 Rust 源代码。
由于我们为 hello_macro_derive 函数标记了 #[proc_macro_derive(HelloMacro)],当用户使用 #[derive(HelloMacro)] 标记了他的类型后,hello_macro_derive 函数就将被调用。这里的秘诀就是特征名 HelloMacro,它就像一座桥梁,将用户的类型和过程宏联系在一起。
syn 将字符串形式的 Rust 代码解析为一个 AST 树的数据结构,该数据结构可以在随后的 impl_hello_macro 函数中进行操作。最后,操作的结果又会被 quote 包转换回 Rust 代码。这些包非常关键,可以帮我们节省大量的精力,否则你需要自己去编写支持代码解析和还原的解析器,这可不是一件简单的任务!
derive过程宏只能用在struct/enum/union上,多数用在结构体上,我们先来看一下一个结构体由哪些部分组成:
#![allow(unused)] fn main() { // vis,可视范围 ident,标识符 generic,范型 fields: 结构体的字段 pub struct User <'a, T> { // vis ident type pub name: &'a T, } }
其中type还可以细分,具体请阅读syn文档或源码
syn::parse 调用会返回一个 DeriveInput 结构体来代表解析后的 Rust 代码:
#![allow(unused)] fn main() { DeriveInput { // --snip-- vis: Visibility, generics: Generics ident: Ident { ident: "Sunfei", span: #0 bytes(95..103) }, // Data是一个枚举,分别是DataStruct,DataEnum,DataUnion,这里以 DataStruct 为例 data: Data( DataStruct { struct_token: Struct, fields: Fields, semi_token: Some( Semi ) } ) } }
以上就是源代码 struct Sunfei; 解析后的结果,里面有几点值得注意:
fields: Fields是一个枚举类型,FieldsNamed,FieldsUnnamed,FieldsUnnamed, 分别表示显示命名结构(如例子所示),匿名字段的结构(例如 struct A(u8);),和无字段定义的结构(例如 struct A;)ident: "Sunfei"说明类型名称为Sunfei,ident是标识符identifier的简写
如果想要了解更多的信息,可以查看 syn 文档。
大家可能会注意到在 hello_macro_derive 函数中有 unwrap 的调用,也许会以为这是为了演示目的,没有做错误处理,实际上并不是的。由于该函数只能返回 TokenStream 而不是 Result,那么在报错时直接 panic 来抛出错误就成了相当好的选择。当然,这里实际上还是做了简化,在生产项目中,你应该通过 panic! 或 expect 抛出更具体的报错信息。
至此,这个函数大家应该已经基本理解了,下面来看看如何构建特征实现的代码,也是过程宏的核心目标:
#![allow(unused)] fn main() { fn impl_hello_macro(ast: &syn::DeriveInput) -> TokenStream { let name = &ast.ident; let gen = quote! { impl HelloMacro for #name { fn hello_macro() { println!("Hello, Macro! My name is {}!", stringify!(#name)); } } }; gen.into() } }
首先,将结构体的名称赋予给 name,也就是 name 中会包含一个字段,它的值是字符串 "Sunfei"。
其次,使用 quote! 可以定义我们想要返回的 Rust 代码。由于编译器需要的内容和 quote! 直接返回的不一样,因此还需要使用 .into 方法其转换为 TokenStream。
大家注意到 #name 的使用了吗?这也是 quote! 提供的功能之一,如果想要深入了解 quote,可以看看官方文档。
特征的 hell_macro() 函数只有一个功能,就是使用 println! 打印一行欢迎语句。
其中 stringify! 是 Rust 提供的内置宏,可以将一个表达式(例如 1 + 2)在编译期转换成一个字符串字面值("1 + 2"),该字面量会直接打包进编译出的二进制文件中,具有 'static 生命周期。而 format! 宏会对表达式进行求值,最终结果是一个 String 类型。在这里使用 stringify! 有两个好处:
#name可能是一个表达式,我们需要它的字面值形式- 可以减少一次
String带来的内存分配
在运行之前,可以显示用 expand 展开宏,观察是否有错误或是否符合预期:
$ cargo expand
struct Sunfei; impl HelloMacro for Sunfei { fn hello_macro() { { ::std::io::_print( ::core::fmt::Arguments::new_v1( &["Hello, Macro! My name is ", "!\n"], &[::core::fmt::ArgumentV1::new_display(&"Sunfei")], ), ); }; } } struct Sunface; impl HelloMacro for Sunface { fn hello_macro() { { ::std::io::_print( ::core::fmt::Arguments::new_v1( &["Hello, Macro! My name is ", "!\n"], &[::core::fmt::ArgumentV1::new_display(&"Sunface")], ), ); }; } } fn main() { Sunfei::hello_macro(); Sunface::hello_macro(); }
从展开的代码也能看出derive宏的特性,struct Sunfei; 和 struct Sunface; 都被保留了,也就是说最后 impl_hello_macro() 返回的token被加到结构体后面,这和类属性宏可以修改输入 的token是不一样的,input的token并不能被修改
至此,过程宏的定义、特征定义、主体代码都已经完成,运行下试试:
$ cargo run
Running `target/debug/hello_macro`
Hello, Macro! My name is Sunfei!
Hello, Macro! My name is Sunface!
Bingo,虽然过程有些复杂,但是结果还是很喜人,我们终于完成了自己的第一个过程宏!
下面来实现一个更实用的例子,实现官方的#[derive(Default)]宏,废话不说直接开干:
#![allow(unused)] fn main() { extern crate proc_macro; use proc_macro::TokenStream; use quote::quote; use syn::{self, Data}; use syn::DeriveInput; #[proc_macro_derive(MyDefault)] pub fn my_default(input: TokenStream) -> TokenStream { let ast: DeriveInput = syn::parse(input).unwrap(); let id = ast.ident; let Data::Struct(s) = ast.data else{ panic!("MyDefault derive macro must use in struct"); }; // 声明一个新的ast,用于动态构建字段赋值的token let mut field_ast = quote!(); // 这里就是要动态添加token的地方了,需要动态完成Self的字段赋值 for (idx,f) in s.fields.iter().enumerate() { let (field_id, field_ty) = (&f.ident, &f.ty); if field_id.is_none(){ //没有ident表示是匿名字段,对于匿名字段,都需要添加 `#field_idx: #field_type::default(),` 这样的代码 let field_idx = syn::Index::from(idx); field_ast.extend(quote! { field_idx: # field_ty::default(), }); }else{ //对于命名字段,都需要添加 `#field_name: #field_type::default(),` 这样的代码 field_ast.extend(quote! { field_id: # field_ty::default(), }); } } quote! { impl Default for # id { fn default() -> Self { Self { field_ast } } } }.into() } }
然后来写使用代码:
#[derive(MyDefault)] struct SomeData (u32,String); #[derive(MyDefault)] struct User { name: String, data: SomeData, } fn main() { }
然后我们先展开代码看一看
struct SomeData(u32, String); impl Default for SomeData { fn default() -> Self { Self { 0: u32::default(), 1: String::default(), } } } struct User { name: String, data: SomeData, } impl Default for User { fn default() -> Self { Self { name: String::default(), data: SomeData::default(), } } } fn main() {}
展开的代码符合预期,然后我们修改一下使用代码并测试结果
#[derive(MyDefault, Debug)] struct SomeData (u32,String); #[derive(MyDefault, Debug)] struct User { name: String, data: SomeData, } fn main() { println!("{:?}", User::default()); }
执行
$ cargo run
Running `target/debug/aaa`
User { name: "", data: SomeData(0, "") }
接下来,再来看看过程宏的另外两种类型跟 derive 类型有何区别。
类属性宏(Attribute-like macros)
类属性过程宏跟 derive 宏类似,但是前者允许我们定义自己的属性。除此之外,derive 只能用于结构体和枚举,而类属性宏可以用于其它类型项,例如函数。
假设我们在开发一个 web 框架,当用户通过 HTTP GET 请求访问 / 根路径时,使用 index 函数为其提供服务:
#![allow(unused)] fn main() { #[route(GET, "/")] fn index() { }
如上所示,代码功能非常清晰、简洁,这里的 #[route] 属性就是一个过程宏,它的定义函数大概如下:
#![allow(unused)] fn main() { #[proc_macro_attribute] pub fn route(attr: TokenStream, item: TokenStream) -> TokenStream { }
与 derive 宏不同,类属性宏的定义函数有两个参数:
- 第一个参数时用于说明属性包含的内容:
Get, "/"部分 - 第二个是属性所标注的类型项,在这里是
fn index() {...},注意,函数体也被包含其中
除此之外,类属性宏跟 derive 宏的工作方式并无区别:创建一个包,类型是 proc-macro,接着实现一个函数用于生成想要的代码。
类函数宏(Function-like macros)
类函数宏可以让我们定义像函数那样调用的宏,从这个角度来看,它跟声明宏 macro_rules 较为类似。
区别在于,macro_rules 的定义形式与 match 匹配非常相像,而类函数宏的定义形式则类似于之前讲过的两种过程宏:
#![allow(unused)] fn main() { #[proc_macro] pub fn sql(input: TokenStream) -> TokenStream { }
而使用形式则类似于函数调用:
#![allow(unused)] fn main() { let sql = sql!(SELECT * FROM posts WHERE id=1); }
大家可能会好奇,为何我们不使用声明宏 macro_rules 来定义呢?原因是这里需要对 SQL 语句进行解析并检查其正确性,这个复杂的过程是 macro_rules 难以对付的,而过程宏相比起来就会灵活的多。
补充学习资料
- dtolnay/proc-macro-workshop,学习如何编写过程宏
- The Little Book of Rust Macros,学习如何编写声明宏
macro_rules! - syn 和 quote ,用于编写过程宏的包,它们的文档有很多值得学习的东西
- Structuring, testing and debugging procedural macro crates,从测试、debug、结构化的角度来编写过程宏
- blog.turbo.fish,里面的过程宏系列文章值得一读
- Rust 宏小册中文版,非常详细的解释了宏各种知识
总结
Rust 中的宏主要分为两大类:声明宏和过程宏。
声明宏目前使用 macro_rules 进行创建,它的形式类似于 match 匹配,对于用户而言,可读性和维护性都较差。由于其存在的问题和限制,在未来, macro_rules 会被 deprecated,Rust 会使用一个新的声明宏来替代它。
而过程宏的定义更像是我们平时写函数的方式,因此它更加灵活,它分为三种类型:derive 宏、类属性宏、类函数宏,具体在文中都有介绍。
虽然 Rust 中的宏很强大,但是它并不应该成为我们的常规武器,原因是它会影响 Rust 代码的可读性和可维护性,我相信没有几个人愿意去维护别人写的宏 :)
因此,大家应该熟悉宏的使用场景,但是不要滥用,当你真的需要时,再回来查看本章了解实现细节,这才是最完美的使用方式。
Rust 异步编程
在艰难的学完 Rust 入门和进阶所有的 70 个章节后,我们终于来到了这里。假如之前攀登的是珠穆朗玛峰,那么现在攀登的就是乔戈里峰( 比珠峰还难攀爬... )。
如果你想开发 Web 服务器、数据库驱动、消息服务等需要高并发的服务,那么本章的内容将值得认真对待和学习,将从以下方面深入讲解 Rust 的异步编程:
- Rust 异步编程的通用概念介绍
- Future 以及异步任务调度
- async/await 和 Pin/Unpin
- 异步编程常用的三方库
- tokio 库
- 一些示例
异步编程
接下来,我们将深入了解 async/await 的使用方式及背后的原理。
本章在内容上大量借鉴和翻译了原版英文书籍Asynchronous Programming In Rust, 特此感谢
Async 编程简介
众所周知,Rust 可以让我们写出性能高且安全的软件,那么异步编程这块儿呢?是否依然在高性能的同时保证了安全?
我们先通过一张 web 框架性能对比图来感受下 Rust 异步编程的性能:
上图并不能说 Rust 写的 actix 框架比 Go 的 gin 更好、更优秀,但是确实可以一定程度上说明 Rust 的异步性能非常的高!
简单来说,异步编程是一个并发编程模型,目前主流语言基本都支持了,当然,支持的方式有所不同。异步编程允许我们同时并发运行大量的任务,却仅仅需要几个甚至一个 OS 线程或 CPU 核心,现代化的异步编程在使用体验上跟同步编程也几无区别,例如 Go 语言的 go 关键字,也包括我们后面将介绍的 async/await 语法,该语法是 JavaScript 和 Rust 的核心特性之一。
async 简介
async 是 Rust 选择的异步编程模型,下面我们来介绍下它的优缺点,以及何时适合使用。
async vs 其它并发模型
由于并发编程在现代社会非常重要,因此每个主流语言都对自己的并发模型进行过权衡取舍和精心设计,Rust 语言也不例外。下面的列表可以帮助大家理解不同并发模型的取舍:
- OS 线程, 它最简单,也无需改变任何编程模型(业务/代码逻辑),因此非常适合作为语言的原生并发模型,我们在多线程章节也提到过,Rust 就选择了原生支持线程级的并发编程。但是,这种模型也有缺点,例如线程间的同步将变得更加困难,线程间的上下文切换损耗较大。使用线程池在一定程度上可以提升性能,但是对于 IO 密集的场景来说,线程池还是不够。
- 事件驱动(Event driven), 这个名词你可能比较陌生,如果说事件驱动常常跟回调( Callback )一起使用,相信大家就恍然大悟了。这种模型性能相当的好,但最大的问题就是存在回调地狱的风险:非线性的控制流和结果处理导致了数据流向和错误传播变得难以掌控,还会导致代码可维护性和可读性的大幅降低,大名鼎鼎的
JavaScript曾经就存在回调地狱。 - 协程(Coroutines) 可能是目前最火的并发模型,
Go语言的协程设计就非常优秀,这也是Go语言能够迅速火遍全球的杀手锏之一。协程跟线程类似,无需改变编程模型,同时,它也跟async类似,可以支持大量的任务并发运行。但协程抽象层次过高,导致用户无法接触到底层的细节,这对于系统编程语言和自定义异步运行时是难以接受的 - actor 模型是 erlang 的杀手锏之一,它将所有并发计算分割成一个一个单元,这些单元被称为
actor, 单元之间通过消息传递的方式进行通信和数据传递,跟分布式系统的设计理念非常相像。由于actor模型跟现实很贴近,因此它相对来说更容易实现,但是一旦遇到流控制、失败重试等场景时,就会变得不太好用 - async/await, 该模型性能高,还能支持底层编程,同时又像线程和协程那样无需过多的改变编程模型,但有得必有失,
async模型的问题就是内部实现机制过于复杂,对于用户来说,理解和使用起来也没有线程和协程简单,好在前者的复杂性开发者们已经帮我们封装好,而理解和使用起来不够简单,正是本章试图解决的问题。
总之,Rust 经过权衡取舍后,最终选择了同时提供多线程编程和 async 编程:
- 前者通过标准库实现,当你无需那么高的并发时,例如需要并行计算时,可以选择它,优点是线程内的代码执行效率更高、实现更直观更简单,这块内容已经在多线程章节进行过深入讲解,不再赘述
- 后者通过语言特性 + 标准库 + 三方库的方式实现,在你需要高并发、异步
I/O时,选择它就对了
async: Rust vs 其它语言
目前已经有诸多语言都通过 async 的方式提供了异步编程,例如 JavaScript ,但 Rust 在实现上有所区别:
- Future 在 Rust 中是惰性的,只有在被轮询(
poll)时才会运行, 因此丢弃一个future会阻止它未来再被运行, 你可以将Future理解为一个在未来某个时间点被调度执行的任务。 - Async 在 Rust 中使用开销是零, 意味着只有你能看到的代码(自己的代码)才有性能损耗,你看不到的(
async内部实现)都没有性能损耗,例如,你可以无需分配任何堆内存、也无需任何动态分发来使用async,这对于热点路径的性能有非常大的好处,正是得益于此,Rust 的异步编程性能才会这么高。 - Rust 没有内置异步调用所必需的运行时,但是无需担心,Rust 社区生态中已经提供了非常优异的运行时实现,例如大明星
tokio - 运行时同时支持单线程和多线程,这两者拥有各自的优缺点,稍后会讲
Rust: async vs 多线程
虽然 async 和多线程都可以实现并发编程,后者甚至还能通过线程池来增强并发能力,但是这两个方式并不互通,从一个方式切换成另一个需要大量的代码重构工作,因此提前为自己的项目选择适合的并发模型就变得至关重要。
OS 线程非常适合少量任务并发,因为线程的创建和上下文切换是非常昂贵的,甚至于空闲的线程都会消耗系统资源。虽说线程池可以有效的降低性能损耗,但是也无法彻底解决问题。当然,线程模型也有其优点,例如它不会破坏你的代码逻辑和编程模型,你之前的顺序代码,经过少量修改适配后依然可以在新线程中直接运行,同时在某些操作系统中,你还可以改变线程的优先级,这对于实现驱动程序或延迟敏感的应用(例如硬实时系统)很有帮助。
对于长时间运行的 CPU 密集型任务,例如并行计算,使用线程将更有优势。 这种密集任务往往会让所在的线程持续运行,任何不必要的线程切换都会带来性能损耗,因此高并发反而在此时成为了一种多余。同时你所创建的线程数应该等于 CPU 核心数,充分利用 CPU 的并行能力,甚至还可以将线程绑定到 CPU 核心上,进一步减少线程上下文切换。
而高并发更适合 IO 密集型任务,例如 web 服务器、数据库连接等等网络服务,因为这些任务绝大部分时间都处于等待状态,如果使用多线程,那线程大量时间会处于无所事事的状态,再加上线程上下文切换的高昂代价,让多线程做 IO 密集任务变成了一件非常奢侈的事。而使用async,既可以有效的降低 CPU 和内存的负担,又可以让大量的任务并发的运行,一个任务一旦处于IO或者其他等待(阻塞)状态,就会被立刻切走并执行另一个任务,而这里的任务切换的性能开销要远远低于使用多线程时的线程上下文切换。
事实上, async 底层也是基于线程实现,但是它基于线程封装了一个运行时,可以将多个任务映射到少量线程上,然后将线程切换变成了任务切换,后者仅仅是内存中的访问,因此要高效的多。
不过async也有其缺点,原因是编译器会为async函数生成状态机,然后将整个运行时打包进来,这会造成我们编译出的二进制可执行文件体积显著增大。
总之,async编程并没有比多线程更好,最终还是根据你的使用场景作出合适的选择,如果无需高并发,或者也不在意线程切换带来的性能损耗,那么多线程使用起来会简单、方便的多!最后再简单总结下:
若大家使用 tokio,那 CPU 密集的任务尤其需要用线程的方式去处理,例如使用
spawn_blocking创建一个阻塞的线程去完成相应 CPU 密集任务。至于具体的原因,不仅是上文说到的那些,还有一个是:tokio 是协作式地调度器,如果某个 CPU 密集的异步任务是通过 tokio 创建的,那理论上来说,该异步任务需要跟其它的异步任务交错执行,最终大家都得到了执行,皆大欢喜。但实际情况是,CPU 密集的任务很可能会一直霸占着 CPU,此时 tokio 的调度方式决定了该任务会一直被执行,这意味着,其它的异步任务无法得到执行的机会,最终这些任务都会因为得不到资源而饿死。
而使用
spawn_blocking后,会创建一个单独的 OS 线程,该线程并不会被 tokio 所调度( 被 OS 所调度 ),因此它所执行的 CPU 密集任务也不会导致 tokio 调度的那些异步任务被饿死
- 有大量
IO任务需要并发运行时,选async模型 - 有部分
IO任务需要并发运行时,选多线程,如果想要降低线程创建和销毁的开销,可以使用线程池 - 有大量
CPU密集任务需要并行运行时,例如并行计算,选多线程模型,且让线程数等于或者稍大于CPU核心数 - 无所谓时,统一选多线程
async 和多线程的性能对比
| 操作 | async | 线程 |
|---|---|---|
| 创建 | 0.3 微秒 | 17 微秒 |
| 线程切换 | 0.2 微秒 | 1.7 微秒 |
可以看出,async 在线程切换的开销显著低于多线程,对于 IO 密集的场景,这种性能开销累计下来会非常可怕!
一个例子
在大概理解async后,我们再来看一个简单的例子。如果想并发的下载文件,你可以使用多线程如下实现:
#![allow(unused)] fn main() { fn get_two_sites() { // 创建两个新线程执行任务 let thread_one = thread::spawn(|| download("https://course.rs")); let thread_two = thread::spawn(|| download("https://fancy.rs")); // 等待两个线程的完成 thread_one.join().expect("thread one panicked"); thread_two.join().expect("thread two panicked"); } }
如果是在一个小项目中简单的去下载文件,这么写没有任何问题,但是一旦下载文件的并发请求多起来,那一个下载任务占用一个线程的模式就太重了,会很容易成为程序的瓶颈。好在,我们可以使用async的方式来解决:
#![allow(unused)] fn main() { async fn get_two_sites_async() { // 创建两个不同的`future`,你可以把`future`理解为未来某个时刻会被执行的计划任务 // 当两个`future`被同时执行后,它们将并发的去下载目标页面 let future_one = download_async("https://www.foo.com"); let future_two = download_async("https://www.bar.com"); // 同时运行两个`future`,直至完成 join!(future_one, future_two); } }
此时,不再有线程创建和切换的昂贵开销,所有的函数都是通过静态的方式进行分发,同时也没有任何内存分配发生。这段代码的性能简直无懈可击!
事实上,async 和多线程并不是二选一,在同一应用中,可以根据情况两者一起使用,当然,我们还可以使用其它的并发模型,例如上面提到事件驱动模型,前提是有三方库提供了相应的实现。
Async Rust 当前的进展
简而言之,Rust 语言的 async 目前还没有达到多线程的成熟度,其中一部分内容还在不断进化中,当然,这并不影响我们在生产级项目中使用,因为社区中还有 tokio 这种大杀器。
使用 async 时,你会遇到好的,也会遇到不好的,例如:
- 收获卓越的性能
- 会经常跟进阶语言特性打交道,例如生命周期等,这些家伙可不好对付
- 一些兼容性问题,例如同步和异步代码、不同的异步运行时(
tokio与async-std) - 更昂贵的维护成本,原因是
async和社区开发的运行时依然在不停的进化
总之,async 在 Rust 中并不是一个善茬,你会遇到更多的困难或者说坑,也会带来更高的代码阅读成本及维护成本,但是为了性能,一切都值了,不是吗?
不过好在,这些进化早晚会彻底稳定成熟,而且在实际项目中,我们往往会使用成熟的三方库,例如tokio,因此可以避免一些类似的问题,但是对于本章的学习来说,async 的一些难点还是我们必须要去面对和征服的。
语言和库的支持
async 的底层实现非常复杂,且会导致编译后文件体积显著增加,因此 Rust 没有选择像 Go 语言那样内置了完整的特性和运行时,而是选择了通过 Rust 语言提供了必要的特性支持,再通过社区来提供 async 运行时的支持。 因此要完整的使用 async 异步编程,你需要依赖以下特性和外部库:
- 所必须的特征(例如
Future)、类型和函数,由标准库提供实现 - 关键字
async/await由 Rust 语言提供,并进行了编译器层面的支持 - 众多实用的类型、宏和函数由官方开发的
futures包提供(不是标准库),它们可以用于任何async应用中。 async代码的执行、IO操作、任务创建和调度等等复杂功能由社区的async运行时提供,例如tokio和async-std
还有,你在同步( synchronous )代码中使用的一些语言特性在 async 中可能将无法再使用,而且 Rust 也不允许你在特征中声明 async 函数(可以通过三方库实现), 总之,你会遇到一些在同步代码中不会遇到的奇奇怪怪、形形色色的问题,不过不用担心,本章会专门用一个章节罗列这些问题,并给出相应的解决方案。
编译和错误
在大多数情况下,async 中的编译错误和运行时错误跟之前没啥区别,但是依然有以下几点值得注意:
- 编译错误,由于
async编程时需要经常使用复杂的语言特性,例如生命周期和Pin,因此相关的错误可能会出现的更加频繁 - 运行时错误,编译器会为每一个
async函数生成状态机,这会导致在栈跟踪时会包含这些状态机的细节,同时还包含了运行时对函数的调用,因此,栈跟踪记录(例如panic时)将变得更加难以解读 - 一些隐蔽的错误也可能发生,例如在一个
async上下文中去调用一个阻塞的函数,或者没有正确的实现Future特征都有可能导致这种错误。这种错误可能会悄无声息的通过编译检查甚至有时候会通过单元测试。好在一旦你深入学习并掌握了本章的内容和async原理,可以有效的降低遇到这些错误的概率
兼容性考虑
异步代码和同步代码并不总能和睦共处。例如,我们无法在一个同步函数中去调用一个 async 异步函数,同步和异步代码也往往使用不同的设计模式,这些都会导致两者融合上的困难。
甚至于有时候,异步代码之间也存在类似的问题,如果一个库依赖于特定的 async 运行时来运行,那么这个库非常有必要告诉它的用户,它用了这个运行时。否则一旦用户选了不同的或不兼容的运行时,就会导致不可预知的麻烦。
性能特性
async 代码的性能主要取决于你使用的 async 运行时,好在这些运行时都经过了精心的设计,在你能遇到的绝大多数场景中,它们都能拥有非常棒的性能表现。
但是世事皆有例外。目前主流的 async 运行时几乎都使用了多线程实现,相比单线程虽然增加了并发表现,但是对于执行性能会有所损失,因为多线程实现会有同步和切换上的性能开销,若你需要极致的顺序执行性能,那么 async 目前并不是一个好的选择。
同样的,对于延迟敏感的任务来说,任务的执行次序需要能被严格掌控,而不是交由运行时去自动调度,后者会导致不可预知的延迟,例如一个 web 服务器总是有 1% 的请求,它们的延迟会远高于其它请求,因为调度过于繁忙导致了部分任务被延迟调度,最终导致了较高的延时。正因为此,这些延迟敏感的任务非常依赖于运行时或操作系统提供调度次序上的支持。
以上的两个需求,目前的 async 运行时并不能很好的支持,在未来可能会有更好的支持,但在此之前,我们可以尝试用多线程解决。
async/.await 简单入门
async/.await 是 Rust 内置的语言特性,可以让我们用同步的方式去编写异步的代码。
通过 async 标记的语法块会被转换成实现了Future特征的状态机。 与同步调用阻塞当前线程不同,当Future执行并遇到阻塞时,它会让出当前线程的控制权,这样其它的Future就可以在该线程中运行,这种方式完全不会导致当前线程的阻塞。
下面我们来通过例子学习 async/.await 关键字该如何使用,在开始之前,需要先引入 futures 包。编辑 Cargo.toml 文件并添加以下内容:
[dependencies]
futures = "0.3"
使用 async
首先,使用 async fn 语法来创建一个异步函数:
#![allow(unused)] fn main() { async fn do_something() { println!("go go go !"); } }
需要注意,异步函数的返回值是一个 Future,若直接调用该函数,不会输出任何结果,因为 Future 还未被执行:
fn main() { do_something(); }
运行后,go go go并没有打印,同时编译器给予一个提示:warning: unused implementer of Future that must be used,告诉我们 Future 未被使用,那么到底该如何使用?答案是使用一个执行器( executor ):
// `block_on`会阻塞当前线程直到指定的`Future`执行完成,这种阻塞当前线程以等待任务完成的方式较为简单、粗暴, // 好在其它运行时的执行器(executor)会提供更加复杂的行为,例如将多个`future`调度到同一个线程上执行。 use futures::executor::block_on; async fn hello_world() { println!("hello, world!"); } fn main() { let future = hello_world(); // 返回一个Future, 因此不会打印任何输出 block_on(future); // 执行`Future`并等待其运行完成,此时"hello, world!"会被打印输出 }
使用.await
在上述代码的main函数中,我们使用block_on这个执行器等待Future的完成,让代码看上去非常像是同步代码,但是如果你要在一个async fn函数中去调用另一个async fn并等待其完成后再执行后续的代码,该如何做?例如:
use futures::executor::block_on; async fn hello_world() { hello_cat(); println!("hello, world!"); } async fn hello_cat() { println!("hello, kitty!"); } fn main() { let future = hello_world(); block_on(future); }
这里,我们在hello_world异步函数中先调用了另一个异步函数hello_cat,然后再输出hello, world!,看看运行结果:
warning: unused implementer of `futures::Future` that must be used
--> src/main.rs:6:5
|
6 | hello_cat();
| ^^^^^^^^^^^^
= note: futures do nothing unless you `.await` or poll them
...
hello, world!
不出所料,main函数中的future我们通过block_on函数进行了运行,但是这里的hello_cat返回的Future却没有任何人去执行它,不过好在编译器友善的给出了提示:futures do nothing unless you `.await` or poll them,两种解决方法:使用.await语法或者对Future进行轮询(poll)。
后者较为复杂,暂且不表,先来使用.await试试:
use futures::executor::block_on; async fn hello_world() { hello_cat().await; println!("hello, world!"); } async fn hello_cat() { println!("hello, kitty!"); } fn main() { let future = hello_world(); block_on(future); }
为hello_cat()添加上.await后,结果立刻大为不同:
hello, kitty!
hello, world!
输出的顺序跟代码定义的顺序完全符合,因此,我们在上面代码中使用同步的代码顺序实现了异步的执行效果,非常简单、高效,而且很好理解,未来也绝对不会有回调地狱的发生。
总之,在async fn函数中使用.await可以等待另一个异步调用的完成。但是与block_on不同,.await并不会阻塞当前的线程,而是异步的等待Future A的完成,在等待的过程中,该线程还可以继续执行其它的Future B,最终实现了并发处理的效果。
一个例子
考虑一个载歌载舞的例子,如果不用.await,我们可能会有如下实现:
use futures::executor::block_on; struct Song { author: String, name: String, } async fn learn_song() -> Song { Song { author: "周杰伦".to_string(), name: String::from("《菊花台》"), } } async fn sing_song(song: Song) { println!( "给大家献上一首{}的{} ~ {}", song.author, song.name, "菊花残,满地伤~ ~" ); } async fn dance() { println!("唱到情深处,身体不由自主的动了起来~ ~"); } fn main() { let song = block_on(learn_song()); block_on(sing_song(song)); block_on(dance()); }
当然,以上代码运行结果无疑是正确的,但。。。它的性能何在?需要通过连续三次阻塞去等待三个任务的完成,一次只能做一件事,实际上我们完全可以载歌载舞啊:
use futures::executor::block_on; struct Song { author: String, name: String, } async fn learn_song() -> Song { Song { author: "曲婉婷".to_string(), name: String::from("《我的歌声里》"), } } async fn sing_song(song: Song) { println!( "给大家献上一首{}的{} ~ {}", song.author, song.name, "你存在我深深的脑海里~ ~" ); } async fn dance() { println!("唱到情深处,身体不由自主的动了起来~ ~"); } async fn learn_and_sing() { // 这里使用`.await`来等待学歌的完成,但是并不会阻塞当前线程,该线程在学歌的任务`.await`后,完全可以去执行跳舞的任务 let song = learn_song().await; // 唱歌必须要在学歌之后 sing_song(song).await; } async fn async_main() { let f1 = learn_and_sing(); let f2 = dance(); // `join!`可以并发的处理和等待多个`Future`,若`learn_and_sing Future`被阻塞,那`dance Future`可以拿过线程的所有权继续执行。若`dance`也变成阻塞状态,那`learn_and_sing`又可以再次拿回线程所有权,继续执行。 // 若两个都被阻塞,那么`async main`会变成阻塞状态,然后让出线程所有权,并将其交给`main`函数中的`block_on`执行器 futures::join!(f1, f2); } fn main() { block_on(async_main()); }
上面代码中,学歌和唱歌具有明显的先后顺序,但是这两者都可以跟跳舞一同存在,也就是你可以在跳舞的时候学歌,也可以在跳舞的时候唱歌。如果上面代码不使用.await,而是使用block_on(learn_song()), 那在学歌时,当前线程就会阻塞,不再可以做其它任何事,包括跳舞。
因此.await对于实现异步编程至关重要,它允许我们在同一个线程内并发的运行多个任务,而不是一个一个先后完成。若大家看到这里还是不太明白,强烈建议回头再仔细看一遍,同时亲自上手修改代码试试效果。
至此,读者应该对 Rust 的async/.await异步编程有了一个清晰的初步印象,下面让我们一起来看看这背后的原理:Future和任务在底层如何被执行。
底层探秘: Future 执行器与任务调度
异步编程背后到底藏有什么秘密?究竟是哪只幕后之手在操纵这一切?如果你对这些感兴趣,就继续看下去,否则可以直接跳过,因为本章节的内容对于一个 API 工程师并没有太多帮助。
但是如果你希望能深入理解 Rust 的 async/.await 代码是如何工作、理解运行时和性能,甚至未来想要构建自己的 async 运行时或相关工具,那么本章节终究不会辜负于你。
Future 特征
Future 特征是 Rust 异步编程的核心,毕竟异步函数是异步编程的核心,而 Future 恰恰是异步函数的返回值和被执行的关键。
首先,来给出 Future 的定义:它是一个能产出值的异步计算(虽然该值可能为空,例如 () )。光看这个定义,可能会觉得很空洞,我们来看看一个简化版的 Future 特征:
#![allow(unused)] fn main() { trait SimpleFuture { type Output; fn poll(&mut self, wake: fn()) -> Poll<Self::Output>; } enum Poll<T> { Ready(T), Pending, } }
在上一章中,我们提到过 Future 需要被执行器poll(轮询)后才能运行,诺,这里 poll 就来了,通过调用该方法,可以推进 Future 的进一步执行,直到被切走为止( 这里不好理解,但是你只需要知道 Future 并不能保证在一次 poll 中就被执行完,后面会详解介绍)。
若在当前 poll 中, Future 可以被完成,则会返回 Poll::Ready(result) ,反之则返回 Poll::Pending, 并且安排一个 wake 函数:当未来 Future 准备好进一步执行时, 该函数会被调用,然后管理该 Future 的执行器(例如上一章节中的block_on函数)会再次调用 poll 方法,此时 Future 就可以继续执行了。
如果没有 wake 方法,那执行器无法知道某个Future是否可以继续被执行,除非执行器定期的轮询每一个 Future ,确认它是否能被执行,但这种作法效率较低。而有了 wake,Future 就可以主动通知执行器,然后执行器就可以精确的执行该 Future。 这种“事件通知 -> 执行”的方式要远比定期对所有 Future 进行一次全遍历来的高效。
也许大家还是迷迷糊糊的,没事,我们用一个例子来说明下。考虑一个需要从 socket 读取数据的场景:如果有数据,可以直接读取数据并返回 Poll::Ready(data), 但如果没有数据,Future 会被阻塞且不会再继续执行,此时它会注册一个 wake 函数,当 socket 数据准备好时,该函数将被调用以通知执行器:我们的 Future 已经准备好了,可以继续执行。
下面的 SocketRead 结构体就是一个 Future:
#![allow(unused)] fn main() { pub struct SocketRead<'a> { socket: &'a Socket, } impl SimpleFuture for SocketRead<'_> { type Output = Vec<u8>; fn poll(&mut self, wake: fn()) -> Poll<Self::Output> { if self.socket.has_data_to_read() { // socket有数据,写入buffer中并返回 Poll::Ready(self.socket.read_buf()) } else { // socket中还没数据 // // 注册一个`wake`函数,当数据可用时,该函数会被调用, // 然后当前Future的执行器会再次调用`poll`方法,此时就可以读取到数据 self.socket.set_readable_callback(wake); Poll::Pending } } } }
这种 Future 模型允许将多个异步操作组合在一起,同时还无需任何内存分配。不仅仅如此,如果你需要同时运行多个 Future或链式调用多个 Future ,也可以通过无内存分配的状态机实现,例如:
#![allow(unused)] fn main() { trait SimpleFuture { type Output; fn poll(&mut self, wake: fn()) -> Poll<Self::Output>; } enum Poll<T> { Ready(T), Pending, } /// 一个SimpleFuture,它会并发地运行两个Future直到它们完成 /// /// 之所以可以并发,是因为两个Future的轮询可以交替进行,一个阻塞,另一个就可以立刻执行,反之亦然 pub struct Join<FutureA, FutureB> { // 结构体的每个字段都包含一个Future,可以运行直到完成. // 如果Future完成后,字段会被设置为 `None`. 这样Future完成后,就不会再被轮询 a: Option<FutureA>, b: Option<FutureB>, } impl<FutureA, FutureB> SimpleFuture for Join<FutureA, FutureB> where FutureA: SimpleFuture<Output = ()>, FutureB: SimpleFuture<Output = ()>, { type Output = (); fn poll(&mut self, wake: fn()) -> Poll<Self::Output> { // 尝试去完成一个 Future `a` if let Some(a) = &mut self.a { if let Poll::Ready(()) = a.poll(wake) { self.a.take(); } } // 尝试去完成一个 Future `b` if let Some(b) = &mut self.b { if let Poll::Ready(()) = b.poll(wake) { self.b.take(); } } if self.a.is_none() && self.b.is_none() { // 两个 Future都已完成 - 我们可以成功地返回了 Poll::Ready(()) } else { // 至少还有一个 Future 没有完成任务,因此返回 `Poll::Pending`. // 当该 Future 再次准备好时,通过调用`wake()`函数来继续执行 Poll::Pending } } } }
上面代码展示了如何同时运行多个 Future, 且在此过程中没有任何内存分配,让并发编程更加高效。 类似的,多个Future也可以一个接一个的连续运行:
#![allow(unused)] fn main() { /// 一个SimpleFuture, 它使用顺序的方式,一个接一个地运行两个Future // // 注意: 由于本例子用于演示,因此功能简单,`AndThenFut` 会假设两个 Future 在创建时就可用了. // 而真实的`Andthen`允许根据第一个`Future`的输出来创建第二个`Future`,因此复杂的多。 pub struct AndThenFut<FutureA, FutureB> { first: Option<FutureA>, second: FutureB, } impl<FutureA, FutureB> SimpleFuture for AndThenFut<FutureA, FutureB> where FutureA: SimpleFuture<Output = ()>, FutureB: SimpleFuture<Output = ()>, { type Output = (); fn poll(&mut self, wake: fn()) -> Poll<Self::Output> { if let Some(first) = &mut self.first { match first.poll(wake) { // 我们已经完成了第一个 Future, 可以将它移除, 然后准备开始运行第二个 Poll::Ready(()) => self.first.take(), // 第一个 Future 还不能完成 Poll::Pending => return Poll::Pending, }; } // 运行到这里,说明第一个Future已经完成,尝试去完成第二个 self.second.poll(wake) } } }
这些例子展示了在不需要内存对象分配以及深层嵌套回调的情况下,该如何使用 Future 特征去表达异步控制流。 在了解了基础的控制流后,我们再来看看真实的 Future 特征有何不同之处。
#![allow(unused)] fn main() { trait Future { type Output; fn poll( // 首先值得注意的地方是,`self`的类型从`&mut self`变成了`Pin<&mut Self>`: self: Pin<&mut Self>, // 其次将`wake: fn()` 修改为 `cx: &mut Context<'_>`: cx: &mut Context<'_>, ) -> Poll<Self::Output>; } }
首先这里多了一个 Pin ,关于它我们会在后面章节详细介绍,现在你只需要知道使用它可以创建一个无法被移动的 Future ,因为无法被移动,因此它将具有固定的内存地址,意味着我们可以存储它的指针(如果内存地址可能会变动,那存储指针地址将毫无意义!),也意味着可以实现一个自引用数据结构: struct MyFut { a: i32, ptr_to_a: *const i32 }。 而对于 async/await 来说,Pin 是不可或缺的关键特性。
其次,从 wake: fn() 变成了 &mut Context<'_> 。意味着 wake 函数可以携带数据了,为何要携带数据?考虑一个真实世界的场景,一个复杂应用例如 web 服务器可能有数千连接同时在线,那么同时就有数千 Future 在被同时管理着,如果不能携带数据,当一个 Future 调用 wake 后,执行器该如何知道是哪个 Future 调用了 wake ,然后进一步去 poll 对应的 Future ?没有办法!那之前的例子为啥就可以使用没有携带数据的 wake ? 因为足够简单,不存在歧义性。
总之,在正式场景要进行 wake ,就必须携带上数据。 而 Context 类型通过提供一个 Waker 类型的值,就可以用来唤醒特定的的任务。
使用 Waker 来唤醒任务
对于 Future 来说,第一次被 poll 时无法完成任务是很正常的。但它需要确保在未来一旦准备好时,可以通知执行器再次对其进行 poll 进而继续往下执行,该通知就是通过 Waker 类型完成的。
Waker 提供了一个 wake() 方法可以用于告诉执行器:相关的任务可以被唤醒了,此时执行器就可以对相应的 Future 再次进行 poll 操作。
构建一个定时器
下面一起来实现一个简单的定时器 Future 。为了让例子尽量简单,当计时器创建时,我们会启动一个线程接着让该线程进入睡眠,等睡眠结束后再通知给 Future 。
注意本例子还会在后面继续使用,因此我们重新创建一个工程来演示:使用 cargo new --lib timer_future 来创建一个新工程,在 lib 包的根路径 src/lib.rs 中添加以下内容:
#![allow(unused)] fn main() { use std::{ future::Future, pin::Pin, sync::{Arc, Mutex}, task::{Context, Poll, Waker}, thread, time::Duration, }; }
继续来实现 Future 定时器,之前提到: 新建线程在睡眠结束后会需要将状态同步给定时器 Future ,由于是多线程环境,我们需要使用 Arc<Mutex<T>> 来作为一个共享状态,用于在新线程和 Future 定时器间共享。
#![allow(unused)] fn main() { pub struct TimerFuture { shared_state: Arc<Mutex<SharedState>>, } /// 在Future和等待的线程间共享状态 struct SharedState { /// 定时(睡眠)是否结束 completed: bool, /// 当睡眠结束后,线程可以用`waker`通知`TimerFuture`来唤醒任务 waker: Option<Waker>, } }
下面给出 Future 的具体实现:
#![allow(unused)] fn main() { impl Future for TimerFuture { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { // 通过检查共享状态,来确定定时器是否已经完成 let mut shared_state = self.shared_state.lock().unwrap(); if shared_state.completed { Poll::Ready(()) } else { // 设置`waker`,这样新线程在睡眠(计时)结束后可以唤醒当前的任务,接着再次对`Future`进行`poll`操作, // // 下面的`clone`每次被`poll`时都会发生一次,实际上,应该是只`clone`一次更加合理。 // 选择每次都`clone`的原因是: `TimerFuture`可以在执行器的不同任务间移动,如果只克隆一次, // 那么获取到的`waker`可能已经被篡改并指向了其它任务,最终导致执行器运行了错误的任务 shared_state.waker = Some(cx.waker().clone()); Poll::Pending } } } }
代码很简单,只要新线程设置了 shared_state.completed = true ,那任务就能顺利结束。如果没有设置,会为当前的任务克隆一份 Waker ,这样新线程就可以使用它来唤醒当前的任务。
最后,再来创建一个 API 用于构建定时器和启动计时线程:
#![allow(unused)] fn main() { impl TimerFuture { /// 创建一个新的`TimerFuture`,在指定的时间结束后,该`Future`可以完成 pub fn new(duration: Duration) -> Self { let shared_state = Arc::new(Mutex::new(SharedState { completed: false, waker: None, })); // 创建新线程 let thread_shared_state = shared_state.clone(); thread::spawn(move || { // 睡眠指定时间实现计时功能 thread::sleep(duration); let mut shared_state = thread_shared_state.lock().unwrap(); // 通知执行器定时器已经完成,可以继续`poll`对应的`Future`了 shared_state.completed = true; if let Some(waker) = shared_state.waker.take() { waker.wake() } }); TimerFuture { shared_state } } } }
至此,一个简单的定时器 Future 就已创建成功,那么该如何使用它呢?相信部分爱动脑筋的读者已经猜到了:我们需要创建一个执行器,才能让程序动起来。
执行器 Executor
Rust 的 Future 是惰性的:只有屁股上拍一拍,它才会努力动一动。其中一个推动它的方式就是在 async 函数中使用 .await 来调用另一个 async 函数,但是这个只能解决 async 内部的问题,那么这些最外层的 async 函数,谁来推动它们运行呢?答案就是我们之前多次提到的执行器 executor 。
执行器会管理一批 Future (最外层的 async 函数),然后通过不停地 poll 推动它们直到完成。 最开始,执行器会先 poll 一次 Future ,后面就不会主动去 poll 了,而是等待 Future 通过调用 wake 函数来通知它可以继续,它才会继续去 poll 。这种wake 通知然后 poll的方式会不断重复,直到 Future 完成。
构建执行器
下面我们将实现一个简单的执行器,它可以同时并发运行多个 Future 。例子中,需要用到 futures 包的 ArcWake 特征,它可以提供一个方便的途径去构建一个 Waker 。编辑 Cargo.toml ,添加下面依赖:
#![allow(unused)] fn main() { [dependencies] futures = "0.3" }
在之前的内容中,我们在 src/lib.rs 中创建了定时器 Future ,现在在 src/main.rs 中来创建程序的主体内容,开始之前,先引入所需的包:
#![allow(unused)] fn main() { use { futures::{ future::{BoxFuture, FutureExt}, task::{waker_ref, ArcWake}, }, std::{ future::Future, sync::mpsc::{sync_channel, Receiver, SyncSender}, sync::{Arc, Mutex}, task::{Context, Poll}, time::Duration, }, // 引入之前实现的定时器模块 timer_future::TimerFuture, }; }
执行器需要从一个消息通道( channel )中拉取事件,然后运行它们。当一个任务准备好后(可以继续执行),它会将自己放入消息通道中,然后等待执行器 poll 。
#![allow(unused)] fn main() { /// 任务执行器,负责从通道中接收任务然后执行 struct Executor { ready_queue: Receiver<Arc<Task>>, } /// `Spawner`负责创建新的`Future`然后将它发送到任务通道中 #[derive(Clone)] struct Spawner { task_sender: SyncSender<Arc<Task>>, } /// 一个Future,它可以调度自己(将自己放入任务通道中),然后等待执行器去`poll` struct Task { /// 进行中的Future,在未来的某个时间点会被完成 /// /// 按理来说`Mutex`在这里是多余的,因为我们只有一个线程来执行任务。但是由于 /// Rust并不聪明,它无法知道`Future`只会在一个线程内被修改,并不会被跨线程修改。因此 /// 我们需要使用`Mutex`来满足这个笨笨的编译器对线程安全的执着。 /// /// 如果是生产级的执行器实现,不会使用`Mutex`,因为会带来性能上的开销,取而代之的是使用`UnsafeCell` future: Mutex<Option<BoxFuture<'static, ()>>>, /// 可以将该任务自身放回到任务通道中,等待执行器的poll task_sender: SyncSender<Arc<Task>>, } fn new_executor_and_spawner() -> (Executor, Spawner) { // 任务通道允许的最大缓冲数(任务队列的最大长度) // 当前的实现仅仅是为了简单,在实际的执行中,并不会这么使用 const MAX_QUEUED_TASKS: usize = 10_000; let (task_sender, ready_queue) = sync_channel(MAX_QUEUED_TASKS); (Executor { ready_queue }, Spawner { task_sender }) } }
下面再来添加一个方法用于生成 Future , 然后将它放入任务通道中:
#![allow(unused)] fn main() { impl Spawner { fn spawn(&self, future: impl Future<Output = ()> + 'static + Send) { let future = future.boxed(); let task = Arc::new(Task { future: Mutex::new(Some(future)), task_sender: self.task_sender.clone(), }); self.task_sender.send(task).expect("任务队列已满"); } } }
在执行器 poll 一个 Future 之前,首先需要调用 wake 方法进行唤醒,然后再由 Waker 负责调度该任务并将其放入任务通道中。创建 Waker 的最简单的方式就是实现 ArcWake 特征,先来为我们的任务实现 ArcWake 特征,这样它们就能被转变成 Waker 然后被唤醒:
#![allow(unused)] fn main() { impl ArcWake for Task { fn wake_by_ref(arc_self: &Arc<Self>) { // 通过发送任务到任务管道的方式来实现`wake`,这样`wake`后,任务就能被执行器`poll` let cloned = arc_self.clone(); arc_self .task_sender .send(cloned) .expect("任务队列已满"); } } }
当任务实现了 ArcWake 特征后,它就变成了 Waker ,在调用 wake() 对其唤醒后会将任务复制一份所有权( Arc ),然后将其发送到任务通道中。最后我们的执行器将从通道中获取任务,然后进行 poll 执行:
#![allow(unused)] fn main() { impl Executor { fn run(&self) { while let Ok(task) = self.ready_queue.recv() { // 获取一个future,若它还没有完成(仍然是Some,不是None),则对它进行一次poll并尝试完成它 let mut future_slot = task.future.lock().unwrap(); if let Some(mut future) = future_slot.take() { // 基于任务自身创建一个 `LocalWaker` let waker = waker_ref(&task); let context = &mut Context::from_waker(&*waker); // `BoxFuture<T>`是`Pin<Box<dyn Future<Output = T> + Send + 'static>>`的类型别名 // 通过调用`as_mut`方法,可以将上面的类型转换成`Pin<&mut dyn Future + Send + 'static>` if future.as_mut().poll(context).is_pending() { // Future还没执行完,因此将它放回任务中,等待下次被poll *future_slot = Some(future); } } } } } }
恭喜!我们终于拥有了自己的执行器,下面再来写一段代码使用该执行器去运行之前的定时器 Future :
fn main() { let (executor, spawner) = new_executor_and_spawner(); // 生成一个任务 spawner.spawn(async { println!("howdy!"); // 创建定时器Future,并等待它完成 TimerFuture::new(Duration::new(2, 0)).await; println!("done!"); }); // drop掉任务,这样执行器就知道任务已经完成,不会再有新的任务进来 drop(spawner); // 运行执行器直到任务队列为空 // 任务运行后,会先打印`howdy!`, 暂停2秒,接着打印 `done!` executor.run(); }
执行器和系统 IO
前面我们一起看过一个使用 Future 从 Socket 中异步读取数据的例子:
#![allow(unused)] fn main() { pub struct SocketRead<'a> { socket: &'a Socket, } impl SimpleFuture for SocketRead<'_> { type Output = Vec<u8>; fn poll(&mut self, wake: fn()) -> Poll<Self::Output> { if self.socket.has_data_to_read() { // socket有数据,写入buffer中并返回 Poll::Ready(self.socket.read_buf()) } else { // socket中还没数据 // // 注册一个`wake`函数,当数据可用时,该函数会被调用, // 然后当前Future的执行器会再次调用`poll`方法,此时就可以读取到数据 self.socket.set_readable_callback(wake); Poll::Pending } } } }
该例子中,Future 将从 Socket 读取数据,若当前还没有数据,则会让出当前线程的所有权,允许执行器去执行其它的 Future 。当数据准备好后,会调用 wake() 函数将该 Future 的任务放入任务通道中,等待执行器的 poll 。
关于该流程已经反复讲了很多次,相信大家应该非常清楚了。然而该例子中还有一个疑问没有解决:
set_readable_callback方法到底是怎么工作的?怎么才能知道socket中的数据已经可以被读取了?
关于第二点,其中一个简单粗暴的方法就是使用一个新线程不停的检查 socket 中是否有了数据,当有了后,就调用 wake() 函数。该方法确实可以满足需求,但是性能着实太低了,需要为每个阻塞的 Future 都创建一个单独的线程!
在现实世界中,该问题往往是通过操作系统提供的 IO 多路复用机制来完成,例如 Linux 中的 epoll,FreeBSD 和 macOS 中的 kqueue ,Windows 中的 IOCP, Fuchisa中的 ports 等(可以通过 Rust 的跨平台包 mio 来使用它们)。借助 IO 多路复用机制,可以实现一个线程同时阻塞地去等待多个异步 IO 事件,一旦某个事件完成就立即退出阻塞并返回数据。相关实现类似于以下代码:
#![allow(unused)] fn main() { struct IoBlocker { /* ... */ } struct Event { // Event的唯一ID,该事件发生后,就会被监听起来 id: usize, // 一组需要等待或者已发生的信号 signals: Signals, } impl IoBlocker { /// 创建需要阻塞等待的异步IO事件的集合 fn new() -> Self { /* ... */ } /// 对指定的IO事件表示兴趣 fn add_io_event_interest( &self, /// 事件所绑定的socket io_object: &IoObject, event: Event, ) { /* ... */ } /// 进入阻塞,直到某个事件出现 fn block(&self) -> Event { /* ... */ } } let mut io_blocker = IoBlocker::new(); io_blocker.add_io_event_interest( &socket_1, Event { id: 1, signals: READABLE }, ); io_blocker.add_io_event_interest( &socket_2, Event { id: 2, signals: READABLE | WRITABLE }, ); let event = io_blocker.block(); // 当socket的数据可以读取时,打印 "Socket 1 is now READABLE" println!("Socket {:?} is now {:?}", event.id, event.signals); }
这样,我们只需要一个执行器线程,它会接收 IO 事件并将其分发到对应的 Waker 中,接着后者会唤醒相关的任务,最终通过执行器 poll 后,任务可以顺利的继续执行, 这种 IO 读取流程可以不停的循环,直到 socket 关闭。
定海神针 Pin 和 Unpin
在 Rust 异步编程中,有一个定海神针般的存在,它就是 Pin ,作用说简单也简单,说复杂也非常复杂,当初刚出来时就连一些 Rust 大佬都一头雾水,何况瑟瑟发抖的我。好在今非昔比,目前网上的资料已经很全,而我就借花献佛,给大家好好讲讲这个Pin。
在 Rust 中,所有的类型可以分为两类:
- 类型的值可以在内存中安全地被移动,例如数值、字符串、布尔值、结构体、枚举,总之你能想到的几乎所有类型都可以落入到此范畴内
- 自引用类型,大魔王来了,大家快跑,在之前章节我们已经见识过它的厉害
下面就是一个自引用类型
#![allow(unused)] fn main() { struct SelfRef { value: String, pointer_to_value: *mut String, } }
在上面的结构体中,pointer_to_value 是一个裸指针,指向第一个字段 value 持有的字符串 String 。很简单对吧?现在考虑一个情况, 若String 被移动了怎么办?
此时一个致命的问题就出现了:新的字符串的内存地址变了,而 pointer_to_value 依然指向之前的地址,一个重大 bug 就出现了!
灾难发生,英雄在哪?只见 Pin 闪亮登场,它可以防止一个类型在内存中被移动。再来回忆下之前在 Future 章节中,我们提到过在 poll 方法的签名中有一个 self: Pin<&mut Self> ,那么为何要在这里使用 Pin 呢?
为何需要 Pin
其实 Pin 还有一个小伙伴 UnPin ,与前者相反,后者表示类型可以在内存中安全地移动。在深入之前,我们先来回忆下 async/.await 是如何工作的:
#![allow(unused)] fn main() { let fut_one = /* ... */; // Future 1 let fut_two = /* ... */; // Future 2 async move { fut_one.await; fut_two.await; } }
在底层,async 会创建一个实现了 Future 的匿名类型,并提供了一个 poll 方法:
#![allow(unused)] fn main() { // `async { ... }`语句块创建的 `Future` 类型 struct AsyncFuture { fut_one: FutOne, fut_two: FutTwo, state: State, } // `async` 语句块可能处于的状态 enum State { AwaitingFutOne, AwaitingFutTwo, Done, } impl Future for AsyncFuture { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { loop { match self.state { State::AwaitingFutOne => match self.fut_one.poll(..) { Poll::Ready(()) => self.state = State::AwaitingFutTwo, Poll::Pending => return Poll::Pending, } State::AwaitingFutTwo => match self.fut_two.poll(..) { Poll::Ready(()) => self.state = State::Done, Poll::Pending => return Poll::Pending, } State::Done => return Poll::Ready(()), } } } } }
当 poll 第一次被调用时,它会去查询 fut_one 的状态,若 fut_one 无法完成,则 poll 方法会返回。未来对 poll 的调用将从上一次调用结束的地方开始。该过程会一直持续,直到 Future 完成为止。
然而,如果我们的 async 语句块中使用了引用类型,会发生什么?例如下面例子:
#![allow(unused)] fn main() { async { let mut x = [0; 128]; let read_into_buf_fut = read_into_buf(&mut x); read_into_buf_fut.await; println!("{:?}", x); } }
这段代码会编译成下面的形式:
#![allow(unused)] fn main() { struct ReadIntoBuf<'a> { buf: &'a mut [u8], // 指向下面的`x`字段 } struct AsyncFuture { x: [u8; 128], read_into_buf_fut: ReadIntoBuf<'what_lifetime?>, } }
这里,ReadIntoBuf 拥有一个引用字段,指向了结构体的另一个字段 x ,一旦 AsyncFuture 被移动,那 x 的地址也将随之变化,此时对 x 的引用就变成了不合法的,也就是 read_into_buf_fut.buf 会变为不合法的。
若能将 Future 在内存中固定到一个位置,就可以避免这种问题的发生,也就可以安全的创建上面这种引用类型。
Unpin
事实上,绝大多数类型都不在意是否被移动(开篇提到的第一种类型),因此它们都自动实现了 Unpin 特征。
从名字推测,大家可能以为 Pin 和 Unpin 都是特征吧?实际上,Pin 不按套路出牌,它是一个结构体:
#![allow(unused)] fn main() { pub struct Pin<P> { pointer: P, } }
它包裹一个指针,并且能确保该指针指向的数据不会被移动,例如 Pin<&mut T> , Pin<&T> , Pin<Box<T>> ,都能确保 T 不会被移动。
而 Unpin 才是一个特征,它表明一个类型可以随意被移动,那么问题来了,可以被 Pin 住的值,它有没有实现什么特征呢? 答案很出乎意料,可以被 Pin 住的值实现的特征是 !Unpin ,大家可能之前没有见过,但是它其实很简单,! 代表没有实现某个特征的意思,!Unpin 说明类型没有实现 Unpin 特征,那自然就可以被 Pin 了。
那是不是意味着类型如果实现了 Unpin 特征,就不能被 Pin 了?其实,还是可以 Pin 的,毕竟它只是一个结构体,你可以随意使用,但是不再有任何效果而已,该值一样可以被移动!
例如 Pin<&mut u8> ,显然 u8 实现了 Unpin 特征,它可以在内存中被移动,因此 Pin<&mut u8> 跟 &mut u8 实际上并无区别,一样可以被移动。
因此,一个类型如果不能被移动,它必须实现 !Unpin 特征。如果大家对 Pin 、 Unpin 还是模模糊糊,建议再重复看一遍之前的内容,理解它们对于我们后面要讲到的内容非常重要!
如果将 Unpin 与之前章节学过的 Send/Sync 进行下对比,会发现它们都很像:
- 都是标记特征( marker trait ),该特征未定义任何行为,非常适用于标记
- 都可以通过
!语法去除实现 - 绝大多数情况都是自动实现, 无需我们的操心
深入理解 Pin
对于上面的问题,我们可以简单的归结为如何在 Rust 中处理自引用类型(果然,只要是难点,都和自引用脱离不了关系),下面用一个稍微简单点的例子来理解下 Pin :
#![allow(unused)] fn main() { #[derive(Debug)] struct Test { a: String, b: *const String, } impl Test { fn new(txt: &str) -> Self { Test { a: String::from(txt), b: std::ptr::null(), } } fn init(&mut self) { let self_ref: *const String = &self.a; self.b = self_ref; } fn a(&self) -> &str { &self.a } fn b(&self) -> &String { assert!(!self.b.is_null(), "Test::b called without Test::init being called first"); unsafe { &*(self.b) } } } }
Test 提供了方法用于获取字段 a 和 b 的值的引用。这里b 是 a 的一个引用,但是我们并没有使用引用类型而是用了裸指针,原因是:Rust 的借用规则不允许我们这样用,因为不符合生命周期的要求。 此时的 Test 就是一个自引用结构体。
如果不移动任何值,那么上面的例子将没有任何问题,例如:
fn main() { let mut test1 = Test::new("test1"); test1.init(); let mut test2 = Test::new("test2"); test2.init(); println!("a: {}, b: {}", test1.a(), test1.b()); println!("a: {}, b: {}", test2.a(), test2.b()); }
输出非常正常:
a: test1, b: test1
a: test2, b: test2
明知山有虎,偏向虎山行,这才是我辈年轻人的风华。既然移动数据会导致指针不合法,那我们就移动下数据试试,将 test1 和 test2 进行下交换:
fn main() { let mut test1 = Test::new("test1"); test1.init(); let mut test2 = Test::new("test2"); test2.init(); println!("a: {}, b: {}", test1.a(), test1.b()); std::mem::swap(&mut test1, &mut test2); println!("a: {}, b: {}", test2.a(), test2.b()); }
按理来说,这样修改后,输出应该如下:
#![allow(unused)] fn main() { a: test1, b: test1 a: test1, b: test1 }
但是实际运行后,却产生了下面的输出:
#![allow(unused)] fn main() { a: test1, b: test1 a: test1, b: test2 }
原因是 test2.b 指针依然指向了旧的地址,而该地址对应的值现在在 test1 里,最终会打印出意料之外的值。
如果大家还是将信将疑,那再看看下面的代码:
fn main() { let mut test1 = Test::new("test1"); test1.init(); let mut test2 = Test::new("test2"); test2.init(); println!("a: {}, b: {}", test1.a(), test1.b()); std::mem::swap(&mut test1, &mut test2); test1.a = "I've totally changed now!".to_string(); println!("a: {}, b: {}", test2.a(), test2.b()); }
下面的图片也可以帮助更好的理解这个过程:
Pin 在实践中的运用
在理解了 Pin 的作用后,我们再来看看它怎么帮我们解决问题。
将值固定到栈上
回到之前的例子,我们可以用 Pin 来解决指针指向的数据被移动的问题:
#![allow(unused)] fn main() { use std::pin::Pin; use std::marker::PhantomPinned; #[derive(Debug)] struct Test { a: String, b: *const String, _marker: PhantomPinned, } impl Test { fn new(txt: &str) -> Self { Test { a: String::from(txt), b: std::ptr::null(), _marker: PhantomPinned, // 这个标记可以让我们的类型自动实现特征`!Unpin` } } fn init(self: Pin<&mut Self>) { let self_ptr: *const String = &self.a; let this = unsafe { self.get_unchecked_mut() }; this.b = self_ptr; } fn a(self: Pin<&Self>) -> &str { &self.get_ref().a } fn b(self: Pin<&Self>) -> &String { assert!(!self.b.is_null(), "Test::b called without Test::init being called first"); unsafe { &*(self.b) } } } }
上面代码中,我们使用了一个标记类型 PhantomPinned 将自定义结构体 Test 变成了 !Unpin (编译器会自动帮我们实现),因此该结构体无法再被移动。
一旦类型实现了 !Unpin ,那将它的值固定到栈( stack )上就是不安全的行为,因此在代码中我们使用了 unsafe 语句块来进行处理,你也可以使用 pin_utils 来避免 unsafe 的使用。
BTW, Rust 中的 unsafe 其实没有那么可怕,虽然听上去很不安全,但是实际上 Rust 依然提供了很多机制来帮我们提升了安全性,因此不必像对待 Go 语言的
unsafe那样去畏惧于使用 Rust 中的unsafe,大致使用原则总结如下:没必要用时,就不要用,当有必要用时,就大胆用,但是尽量控制好边界,让unsafe的范围尽可能小
此时,再去尝试移动被固定的值,就会导致编译错误 :
pub fn main() { // 此时的`test1`可以被安全的移动 let mut test1 = Test::new("test1"); // 新的`test1`由于使用了`Pin`,因此无法再被移动,这里的声明会将之前的`test1`遮蔽掉(shadow) let mut test1 = unsafe { Pin::new_unchecked(&mut test1) }; Test::init(test1.as_mut()); let mut test2 = Test::new("test2"); let mut test2 = unsafe { Pin::new_unchecked(&mut test2) }; Test::init(test2.as_mut()); println!("a: {}, b: {}", Test::a(test1.as_ref()), Test::b(test1.as_ref())); std::mem::swap(test1.get_mut(), test2.get_mut()); println!("a: {}, b: {}", Test::a(test2.as_ref()), Test::b(test2.as_ref())); }
注意到之前的粗体字了吗?是的,Rust 并不是在运行时做这件事,而是在编译期就完成了,因此没有额外的性能开销!来看看报错:
error[E0277]: `PhantomPinned` cannot be unpinned
--> src/main.rs:47:43
|
47 | std::mem::swap(test1.get_mut(), test2.get_mut());
| ^^^^^^^ within `Test`, the trait `Unpin` is not implemented for `PhantomPinned`
需要注意的是固定在栈上非常依赖于你写出的
unsafe代码的正确性。我们知道&'a mut T可以固定的生命周期是'a,但是我们却不知道当生命周期'a结束后,该指针指向的数据是否会被移走。如果你的unsafe代码里这么实现了,那么就会违背Pin应该具有的作用!一个常见的错误就是忘记去遮蔽(shadow )初始的变量,因为你可以
drop掉Pin,然后在&'a mut T结束后去移动数据:fn main() { let mut test1 = Test::new("test1"); let mut test1_pin = unsafe { Pin::new_unchecked(&mut test1) }; Test::init(test1_pin.as_mut()); drop(test1_pin); println!(r#"test1.b points to "test1": {:?}..."#, test1.b); let mut test2 = Test::new("test2"); mem::swap(&mut test1, &mut test2); println!("... and now it points nowhere: {:?}", test1.b); } use std::pin::Pin; use std::marker::PhantomPinned; use std::mem; #[derive(Debug)] struct Test { a: String, b: *const String, _marker: PhantomPinned, } impl Test { fn new(txt: &str) -> Self { Test { a: String::from(txt), b: std::ptr::null(), // This makes our type `!Unpin` _marker: PhantomPinned, } } fn init<'a>(self: Pin<&'a mut Self>) { let self_ptr: *const String = &self.a; let this = unsafe { self.get_unchecked_mut() }; this.b = self_ptr; } fn a<'a>(self: Pin<&'a Self>) -> &'a str { &self.get_ref().a } fn b<'a>(self: Pin<&'a Self>) -> &'a String { assert!(!self.b.is_null(), "Test::b called without Test::init being called first"); unsafe { &*(self.b) } } }
固定到堆上
将一个 !Unpin 类型的值固定到堆上,会给予该值一个稳定的内存地址,它指向的堆中的值在 Pin 后是无法被移动的。而且与固定在栈上不同,我们知道堆上的值在整个生命周期内都会被稳稳地固定住。
use std::pin::Pin; use std::marker::PhantomPinned; #[derive(Debug)] struct Test { a: String, b: *const String, _marker: PhantomPinned, } impl Test { fn new(txt: &str) -> Pin<Box<Self>> { let t = Test { a: String::from(txt), b: std::ptr::null(), _marker: PhantomPinned, }; let mut boxed = Box::pin(t); let self_ptr: *const String = &boxed.as_ref().a; unsafe { boxed.as_mut().get_unchecked_mut().b = self_ptr }; boxed } fn a(self: Pin<&Self>) -> &str { &self.get_ref().a } fn b(self: Pin<&Self>) -> &String { unsafe { &*(self.b) } } } pub fn main() { let test1 = Test::new("test1"); let test2 = Test::new("test2"); println!("a: {}, b: {}",test1.as_ref().a(), test1.as_ref().b()); println!("a: {}, b: {}",test2.as_ref().a(), test2.as_ref().b()); }
将固定住的 Future 变为 Unpin
之前的章节我们有提到 async 函数返回的 Future 默认就是 !Unpin 的。
但是,在实际应用中,一些函数会要求它们处理的 Future 是 Unpin 的,此时,若你使用的 Future 是 !Unpin 的,必须要使用以下的方法先将 Future 进行固定:
Box::pin, 创建一个Pin<Box<T>>pin_utils::pin_mut!, 创建一个Pin<&mut T>
固定后获得的 Pin<Box<T>> 和 Pin<&mut T> 既可以用于 Future ,又会自动实现 Unpin。
#![allow(unused)] fn main() { use pin_utils::pin_mut; // `pin_utils` 可以在crates.io中找到 // 函数的参数是一个`Future`,但是要求该`Future`实现`Unpin` fn execute_unpin_future(x: impl Future<Output = ()> + Unpin) { /* ... */ } let fut = async { /* ... */ }; // 下面代码报错: 默认情况下,`fut` 实现的是`!Unpin`,并没有实现`Unpin` // execute_unpin_future(fut); // 使用`Box`进行固定 let fut = async { /* ... */ }; let fut = Box::pin(fut); execute_unpin_future(fut); // OK // 使用`pin_mut!`进行固定 let fut = async { /* ... */ }; pin_mut!(fut); execute_unpin_future(fut); // OK }
总结
相信大家看到这里,脑袋里已经快被 Pin 、 Unpin 、 !Unpin 整爆炸了,没事,我们再来火上浇油下:)
- 若
T: Unpin( Rust 类型的默认实现),那么Pin<'a, T>跟&'a mut T完全相同,也就是Pin将没有任何效果, 该移动还是照常移动 - 绝大多数标准库类型都实现了
Unpin,事实上,对于 Rust 中你能遇到的绝大多数类型,该结论依然成立 ,其中一个例外就是:async/await生成的Future没有实现Unpin - 你可以通过以下方法为自己的类型添加
!Unpin约束:- 使用文中提到的
std::marker::PhantomPinned - 使用
nightly版本下的feature flag
- 使用文中提到的
- 可以将值固定到栈上,也可以固定到堆上
- 将
!Unpin值固定到栈上需要使用unsafe - 将
!Unpin值固定到堆上无需unsafe,可以通过Box::pin来简单的实现
- 将
- 当固定类型
T: !Unpin时,你需要保证数据从被固定到被 drop 这段时期内,其内存不会变得非法或者被重用
async/await 和 Stream 流处理
在入门章节中,我们简单学习了该如何使用 async/.await, 同时在后面也了解了一些底层原理,现在是时候继续深入了。
async/.await是 Rust 语法的一部分,它在遇到阻塞操作时( 例如 IO )会让出当前线程的所有权而不是阻塞当前线程,这样就允许当前线程继续去执行其它代码,最终实现并发。
有两种方式可以使用async: async fn用于声明函数,async { ... }用于声明语句块,它们会返回一个实现 Future 特征的值:
#![allow(unused)] fn main() { // `foo()`返回一个`Future<Output = u8>`, // 当调用`foo().await`时,该`Future`将被运行,当调用结束后我们将获取到一个`u8`值 async fn foo() -> u8 { 5 } fn bar() -> impl Future<Output = u8> { // 下面的`async`语句块返回`Future<Output = u8>` async { let x: u8 = foo().await; x + 5 } } }
async 是懒惰的,直到被执行器 poll 或者 .await 后才会开始运行,其中后者是最常用的运行 Future 的方法。 当 .await 被调用时,它会尝试运行 Future 直到完成,但是若该 Future 进入阻塞,那就会让出当前线程的控制权。当 Future 后面准备再一次被运行时(例如从 socket 中读取到了数据),执行器会得到通知,并再次运行该 Future ,如此循环,直到完成。
以上过程只是一个简述,详细内容在底层探秘中已经被深入讲解过,因此这里不再赘述。
async 的生命周期
async fn 函数如果拥有引用类型的参数,那它返回的 Future 的生命周期就会被这些参数的生命周期所限制:
#![allow(unused)] fn main() { async fn foo(x: &u8) -> u8 { *x } // 上面的函数跟下面的函数是等价的: fn foo_expanded<'a>(x: &'a u8) -> impl Future<Output = u8> + 'a { async move { *x } } }
意味着 async fn 函数返回的 Future 必须满足以下条件: 当 x 依然有效时, 该 Future 就必须继续等待( .await ), 也就是说x 必须比 Future活得更久。
在一般情况下,在函数调用后就立即 .await 不会存在任何问题,例如foo(&x).await。但是,若 Future 被先存起来或发送到另一个任务或者线程,就可能存在问题了:
#![allow(unused)] fn main() { use std::future::Future; fn bad() -> impl Future<Output = u8> { let x = 5; borrow_x(&x) // ERROR: `x` does not live long enough } async fn borrow_x(x: &u8) -> u8 { *x } }
以上代码会报错,因为 x 的生命周期只到 bad 函数的结尾。 但是 Future 显然会活得更久:
error[E0597]: `x` does not live long enough
--> src/main.rs:4:14
|
4 | borrow_x(&x) // ERROR: `x` does not live long enough
| ---------^^-
| | |
| | borrowed value does not live long enough
| argument requires that `x` is borrowed for `'static`
5 | }
| - `x` dropped here while still borrowed
其中一个常用的解决方法就是将具有引用参数的 async fn 函数转变成一个具有 'static 生命周期的 Future 。 以上解决方法可以通过将参数和对 async fn 的调用放在同一个 async 语句块来实现:
#![allow(unused)] fn main() { use std::future::Future; async fn borrow_x(x: &u8) -> u8 { *x } fn good() -> impl Future<Output = u8> { async { let x = 5; borrow_x(&x).await } } }
如上所示,通过将参数移动到 async 语句块内, 我们将它的生命周期扩展到 'static, 并跟返回的 Future 保持了一致。
async move
async 允许我们使用 move 关键字来将环境中变量的所有权转移到语句块内,就像闭包那样,好处是你不再发愁该如何解决借用生命周期的问题,坏处就是无法跟其它代码实现对变量的共享:
#![allow(unused)] fn main() { // 多个不同的 `async` 语句块可以访问同一个本地变量,只要它们在该变量的作用域内执行 async fn blocks() { let my_string = "foo".to_string(); let future_one = async { // ... println!("{my_string}"); }; let future_two = async { // ... println!("{my_string}"); }; // 运行两个 Future 直到完成 let ((), ()) = futures::join!(future_one, future_two); } // 由于`async move`会捕获环境中的变量,因此只有一个`async move`语句块可以访问该变量, // 但是它也有非常明显的好处: 变量可以转移到返回的 Future 中,不再受借用生命周期的限制 fn move_block() -> impl Future<Output = ()> { let my_string = "foo".to_string(); async move { // ... println!("{my_string}"); } } }
当.await 遇见多线程执行器
需要注意的是,当使用多线程 Future 执行器( executor )时, Future 可能会在线程间被移动,因此 async 语句块中的变量必须要能在线程间传递。 至于 Future 会在线程间移动的原因是:它内部的任何.await都可能导致它被切换到一个新线程上去执行。
由于需要在多线程环境使用,意味着 Rc、 RefCell 、没有实现 Send 的所有权类型、没有实现 Sync 的引用类型,它们都是不安全的,因此无法被使用
需要注意!实际上它们还是有可能被使用的,只要在
.await调用期间,它们没有在作用域范围内。
类似的原因,在 .await 时使用普通的锁也不安全,例如 Mutex 。原因是,它可能会导致线程池被锁:当一个任务获取锁 A 后,若它将线程的控制权还给执行器,然后执行器又调度运行另一个任务,该任务也去尝试获取了锁 A ,结果当前线程会直接卡死,最终陷入死锁中。
因此,为了避免这种情况的发生,我们需要使用 futures 包下的锁 futures::lock 来替代 Mutex 完成任务。
Stream 流处理
Stream 特征类似于 Future 特征,但是前者在完成前可以生成多个值,这种行为跟标准库中的 Iterator 特征倒是颇为相似。
#![allow(unused)] fn main() { trait Stream { // Stream生成的值的类型 type Item; // 尝试去解析Stream中的下一个值, // 若无数据,返回`Poll::Pending`, 若有数据,返回 `Poll::Ready(Some(x))`, `Stream`完成则返回 `Poll::Ready(None)` fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>; } }
关于 Stream 的一个常见例子是消息通道( futures 包中的)的消费者 Receiver。每次有消息从 Send 端发送后,它都可以接收到一个 Some(val) 值, 一旦 Send 端关闭(drop),且消息通道中没有消息后,它会接收到一个 None 值。
#![allow(unused)] fn main() { async fn send_recv() { const BUFFER_SIZE: usize = 10; let (mut tx, mut rx) = mpsc::channel::<i32>(BUFFER_SIZE); tx.send(1).await.unwrap(); tx.send(2).await.unwrap(); drop(tx); // `StreamExt::next` 类似于 `Iterator::next`, 但是前者返回的不是值,而是一个 `Future<Output = Option<T>>`, // 因此还需要使用`.await`来获取具体的值 assert_eq!(Some(1), rx.next().await); assert_eq!(Some(2), rx.next().await); assert_eq!(None, rx.next().await); } }
迭代和并发
跟迭代器类似,我们也可以迭代一个 Stream。 例如使用map,filter,fold方法,以及它们的遇到错误提前返回的版本: try_map,try_filter,try_fold。
但是跟迭代器又有所不同,for 循环无法在这里使用,但是命令式风格的循环while let是可以用的,同时还可以使用next 和 try_next 方法:
#![allow(unused)] fn main() { async fn sum_with_next(mut stream: Pin<&mut dyn Stream<Item = i32>>) -> i32 { use futures::stream::StreamExt; // 引入 next let mut sum = 0; while let Some(item) = stream.next().await { sum += item; } sum } async fn sum_with_try_next( mut stream: Pin<&mut dyn Stream<Item = Result<i32, io::Error>>>, ) -> Result<i32, io::Error> { use futures::stream::TryStreamExt; // 引入 try_next let mut sum = 0; while let Some(item) = stream.try_next().await? { sum += item; } Ok(sum) } }
上面代码是一次处理一个值的模式,但是需要注意的是:如果你选择一次处理一个值的模式,可能会造成无法并发,这就失去了异步编程的意义。 因此,如果可以的话我们还是要选择从一个 Stream 并发处理多个值的方式,通过 for_each_concurrent 或 try_for_each_concurrent 方法来实现:
#![allow(unused)] fn main() { async fn jump_around( mut stream: Pin<&mut dyn Stream<Item = Result<u8, io::Error>>>, ) -> Result<(), io::Error> { use futures::stream::TryStreamExt; // 引入 `try_for_each_concurrent` const MAX_CONCURRENT_JUMPERS: usize = 100; stream.try_for_each_concurrent(MAX_CONCURRENT_JUMPERS, |num| async move { jump_n_times(num).await?; report_n_jumps(num).await?; Ok(()) }).await?; Ok(()) } }
使用join!和select!同时运行多个 Future
招数单一,杀伤力惊人,说的就是 .await ,但是光用它,还真做不到一招鲜吃遍天。比如我们该如何同时运行多个任务,而不是使用.await慢悠悠地排队完成。
join!
futures 包中提供了很多实用的工具,其中一个就是 join!宏, 它允许我们同时等待多个不同 Future 的完成,且可以并发地运行这些 Future。
先来看一个不是很给力的、使用.await的版本:
#![allow(unused)] fn main() { async fn enjoy_book_and_music() -> (Book, Music) { let book = enjoy_book().await; let music = enjoy_music().await; (book, music) } }
这段代码可以顺利运行,但是有一个很大的问题,就是必须先看完书后,才能听音乐。咱们以前,谁又不是那个摇头晃脑爱读书(耳朵里偷偷塞着耳机,听的正 high)的好学生呢?
要支持同时看书和听歌,有些人可能会凭空生成下面代码:
#![allow(unused)] fn main() { // WRONG -- 别这么做 async fn enjoy_book_and_music() -> (Book, Music) { let book_future = enjoy_book(); let music_future = enjoy_music(); (book_future.await, music_future.await) } }
看上去像模像样,嗯,在某些语言中也许可以,但是 Rust 不行。因为在某些语言中,Future一旦创建就开始运行,等到返回的时候,基本就可以同时结束并返回了。 但是 Rust 中的 Future 是惰性的,直到调用 .await 时,才会开始运行。而那两个 await 由于在代码中有先后顺序,因此它们是顺序运行的。
为了正确的并发运行两个 Future , 我们来试试 futures::join! 宏:
#![allow(unused)] fn main() { use futures::join; async fn enjoy_book_and_music() -> (Book, Music) { let book_fut = enjoy_book(); let music_fut = enjoy_music(); join!(book_fut, music_fut) } }
Duang,目标顺利达成。同时join!会返回一个元组,里面的值是对应的Future执行结束后输出的值。
如果希望同时运行一个数组里的多个异步任务,可以使用
futures::future::join_all方法
try_join!
由于join!必须等待它管理的所有 Future 完成后才能完成,如果你希望在某一个 Future 报错后就立即停止所有 Future 的执行,可以使用 try_join!,特别是当 Future 返回 Result 时:
#![allow(unused)] fn main() { use futures::try_join; async fn get_book() -> Result<Book, String> { /* ... */ Ok(Book) } async fn get_music() -> Result<Music, String> { /* ... */ Ok(Music) } async fn get_book_and_music() -> Result<(Book, Music), String> { let book_fut = get_book(); let music_fut = get_music(); try_join!(book_fut, music_fut) } }
有一点需要注意,传给 try_join! 的所有 Future 都必须拥有相同的错误类型。如果错误类型不同,可以考虑使用来自 futures::future::TryFutureExt 模块的 map_err和err_info方法将错误进行转换:
#![allow(unused)] fn main() { use futures::{ future::TryFutureExt, try_join, }; async fn get_book() -> Result<Book, ()> { /* ... */ Ok(Book) } async fn get_music() -> Result<Music, String> { /* ... */ Ok(Music) } async fn get_book_and_music() -> Result<(Book, Music), String> { let book_fut = get_book().map_err(|()| "Unable to get book".to_string()); let music_fut = get_music(); try_join!(book_fut, music_fut) } }
join!很好很强大,但是人无完人,J 无完 J, 它有一个很大的问题。
select!
join!只有等所有 Future 结束后,才能集中处理结果,如果你想同时等待多个 Future ,且任何一个 Future 结束后,都可以立即被处理,可以考虑使用 futures::select!:
#![allow(unused)] fn main() { use futures::{ future::FutureExt, // for `.fuse()` pin_mut, select, }; async fn task_one() { /* ... */ } async fn task_two() { /* ... */ } async fn race_tasks() { let t1 = task_one().fuse(); let t2 = task_two().fuse(); pin_mut!(t1, t2); select! { () = t1 => println!("任务1率先完成"), () = t2 => println!("任务2率先完成"), } } }
上面的代码会同时并发地运行 t1 和 t2, 无论两者哪个先完成,都会调用对应的 println! 打印相应的输出,然后函数结束且不会等待另一个任务的完成。
但是,在实际项目中,我们往往需要等待多个任务都完成后,再结束,像上面这种其中一个任务结束就立刻结束的场景着实不多。
default 和 complete
select!还支持 default 和 complete 分支:
complete分支当所有的Future和Stream完成后才会被执行,它往往配合loop使用,loop用于循环完成所有的Futuredefault分支,若没有任何Future或Stream处于Ready状态, 则该分支会被立即执行
use futures::future; use futures::select; pub fn main() { let mut a_fut = future::ready(4); let mut b_fut = future::ready(6); let mut total = 0; loop { select! { a = a_fut => total += a, b = b_fut => total += b, complete => break, default => panic!(), // 该分支永远不会运行,因为`Future`会先运行,然后是`complete` }; } assert_eq!(total, 10); }
以上代码 default 分支由于最后一个运行,而在它之前 complete 分支已经通过 break 跳出了循环,因此default永远不会被执行。
如果你希望 default 也有机会露下脸,可以将 complete 的 break 修改为其它的,例如println!("completed!"),然后再观察下运行结果。
再回到 select 的第一个例子中,里面有一段代码长这样:
#![allow(unused)] fn main() { let t1 = task_one().fuse(); let t2 = task_two().fuse(); pin_mut!(t1, t2); }
当时没有展开讲,相信大家也有疑惑,下面我们来一起看看。
跟 Unpin 和 FusedFuture 进行交互
首先,.fuse()方法可以让 Future 实现 FusedFuture 特征, 而 pin_mut! 宏会为 Future 实现 Unpin特征,这两个特征恰恰是使用 select 所必须的:
Unpin,由于select不会通过拿走所有权的方式使用Future,而是通过可变引用的方式去使用,这样当select结束后,该Future若没有被完成,它的所有权还可以继续被其它代码使用。FusedFuture的原因跟上面类似,当Future一旦完成后,那select就不能再对其进行轮询使用。Fuse意味着熔断,相当于Future一旦完成,再次调用poll会直接返回Poll::Pending。
只有实现了FusedFuture,select 才能配合 loop 一起使用。假如没有实现,就算一个 Future 已经完成了,它依然会被 select 不停的轮询执行。
Stream 稍有不同,它们使用的特征是 FusedStream。 通过.fuse()(也可以手动实现)实现了该特征的 Stream,对其调用.next() 或 .try_next()方法可以获取实现了FusedFuture特征的Future:
#![allow(unused)] fn main() { use futures::{ stream::{Stream, StreamExt, FusedStream}, select, }; async fn add_two_streams( mut s1: impl Stream<Item = u8> + FusedStream + Unpin, mut s2: impl Stream<Item = u8> + FusedStream + Unpin, ) -> u8 { let mut total = 0; loop { let item = select! { x = s1.next() => x, x = s2.next() => x, complete => break, }; if let Some(next_num) = item { total += next_num; } } total } }
在 select 循环中并发
一个很实用但又鲜为人知的函数是 Fuse::terminated() ,可以使用它构建一个空的 Future ,空自然没啥用,但是如果它能在后面再被填充呢?
考虑以下场景:当你要在select循环中运行一个任务,但是该任务却是在select循环内部创建时,上面的函数就非常好用了。
#![allow(unused)] fn main() { use futures::{ future::{Fuse, FusedFuture, FutureExt}, stream::{FusedStream, Stream, StreamExt}, pin_mut, select, }; async fn get_new_num() -> u8 { /* ... */ 5 } async fn run_on_new_num(_: u8) { /* ... */ } async fn run_loop( mut interval_timer: impl Stream<Item = ()> + FusedStream + Unpin, starting_num: u8, ) { let run_on_new_num_fut = run_on_new_num(starting_num).fuse(); let get_new_num_fut = Fuse::terminated(); pin_mut!(run_on_new_num_fut, get_new_num_fut); loop { select! { () = interval_timer.select_next_some() => { // 定时器已结束,若`get_new_num_fut`没有在运行,就创建一个新的 if get_new_num_fut.is_terminated() { get_new_num_fut.set(get_new_num().fuse()); } }, new_num = get_new_num_fut => { // 收到新的数字 -- 创建一个新的`run_on_new_num_fut`并丢弃掉旧的 run_on_new_num_fut.set(run_on_new_num(new_num).fuse()); }, // 运行 `run_on_new_num_fut` () = run_on_new_num_fut => {}, // 若所有任务都完成,直接 `panic`, 原因是 `interval_timer` 应该连续不断的产生值,而不是结束 //后,执行到 `complete` 分支 complete => panic!("`interval_timer` completed unexpectedly"), } } } }
当某个 Future 有多个拷贝都需要同时运行时,可以使用 FuturesUnordered 类型。下面的例子跟上个例子大体相似,但是它会将 run_on_new_num_fut 的每一个拷贝都运行到完成,而不是像之前那样一旦创建新的就终止旧的。
#![allow(unused)] fn main() { use futures::{ future::{Fuse, FusedFuture, FutureExt}, stream::{FusedStream, FuturesUnordered, Stream, StreamExt}, pin_mut, select, }; async fn get_new_num() -> u8 { /* ... */ 5 } async fn run_on_new_num(_: u8) -> u8 { /* ... */ 5 } // 使用从 `get_new_num` 获取的最新数字 来运行 `run_on_new_num` // // 每当计时器结束后,`get_new_num` 就会运行一次,它会立即取消当前正在运行的`run_on_new_num` , // 并且使用新返回的值来替换 async fn run_loop( mut interval_timer: impl Stream<Item = ()> + FusedStream + Unpin, starting_num: u8, ) { let mut run_on_new_num_futs = FuturesUnordered::new(); run_on_new_num_futs.push(run_on_new_num(starting_num)); let get_new_num_fut = Fuse::terminated(); pin_mut!(get_new_num_fut); loop { select! { () = interval_timer.select_next_some() => { // 定时器已结束,若`get_new_num_fut`没有在运行,就创建一个新的 if get_new_num_fut.is_terminated() { get_new_num_fut.set(get_new_num().fuse()); } }, new_num = get_new_num_fut => { // 收到新的数字 -- 创建一个新的`run_on_new_num_fut` (并没有像之前的例子那样丢弃掉旧值) run_on_new_num_futs.push(run_on_new_num(new_num)); }, // 运行 `run_on_new_num_futs`, 并检查是否有已经完成的 res = run_on_new_num_futs.select_next_some() => { println!("run_on_new_num_fut returned {:?}", res); }, // 若所有任务都完成,直接 `panic`, 原因是 `interval_timer` 应该连续不断的产生值,而不是结束 //后,执行到 `complete` 分支 complete => panic!("`interval_timer` completed unexpectedly"), } } } }
一些疑难问题的解决办法
async 在 Rust 依然比较新,疑难杂症少不了,而它们往往还处于活跃开发状态,短时间内无法被解决,因此才有了本文。下面一起来看看这些问题以及相应的临时解决方案。
在 async 语句块中使用 ?
async 语句块和 async fn 最大的区别就是前者无法显式的声明返回值,在大多数时候这都不是问题,但是当配合 ? 一起使用时,问题就有所不同:
async fn foo() -> Result<u8, String> { Ok(1) } async fn bar() -> Result<u8, String> { Ok(1) } pub fn main() { let fut = async { foo().await?; bar().await?; Ok(()) }; }
以上代码编译后会报错:
error[E0282]: type annotations needed
--> src/main.rs:14:9
|
11 | let fut = async {
| --- consider giving `fut` a type
...
14 | Ok(1)
| ^^ cannot infer type for type parameter `E` declared on the enum `Result`
原因在于编译器无法推断出 Result<T, E>中的 E 的类型, 而且编译器的提示consider giving `fut` a type你也别傻乎乎的相信,然后尝试半天,最后无奈放弃:目前还没有办法为 async 语句块指定返回类型。
既然编译器无法推断出类型,那咱就给它更多提示,可以使用 ::< ... > 的方式来增加类型注释:
#![allow(unused)] fn main() { let fut = async { foo().await?; bar().await?; Ok::<(), String>(()) // 在这一行进行显式的类型注释 }; }
给予类型注释后此时编译器就知道Result<T, E>中的 E 的类型是String,进而成功通过编译。
async 函数和 Send 特征
在多线程章节我们深入讲过 Send 特征对于多线程间数据传递的重要性,对于 async fn 也是如此,它返回的 Future 能否在线程间传递的关键在于 .await 运行过程中,作用域中的变量类型是否是 Send。
学到这里,相信大家已经很清楚Rc无法在多线程环境使用,原因就在于它并未实现 Send 特征,那咱就用它来做例子:
#![allow(unused)] fn main() { use std::rc::Rc; #[derive(Default)] struct NotSend(Rc<()>); }
事实上,未实现 Send 特征的变量可以出现在 async fn 语句块中:
async fn bar() {} async fn foo() { NotSend::default(); bar().await; } fn require_send(_: impl Send) {} fn main() { require_send(foo()); }
即使上面的 foo 返回的 Future 是 Send, 但是在它内部短暂的使用 NotSend 依然是安全的,原因在于它的作用域并没有影响到 .await,下面来试试声明一个变量,然后让 .await的调用处于变量的作用域中试试:
#![allow(unused)] fn main() { async fn foo() { let x = NotSend::default(); bar().await; } }
不出所料,错误如期而至:
error: future cannot be sent between threads safely
--> src/main.rs:17:18
|
17 | require_send(foo());
| ^^^^^ future returned by `foo` is not `Send`
|
= help: within `impl futures::Future<Output = ()>`, the trait `std::marker::Send` is not implemented for `Rc<()>`
note: future is not `Send` as this value is used across an await
--> src/main.rs:11:5
|
10 | let x = NotSend::default();
| - has type `NotSend` which is not `Send`
11 | bar().await;
| ^^^^^^^^^^^ await occurs here, with `x` maybe used later
12 | }
| - `x` is later dropped here
提示很清晰,.await在运行时处于 x 的作用域内。在之前章节有提到过, .await 有可能被执行器调度到另一个线程上运行,而 Rc 并没有实现 Send,因此编译器无情拒绝了咱们。
其中一个可能的解决方法是在 .await 之前就使用 std::mem::drop 释放掉 Rc,但是很可惜,截止今天,该方法依然不能解决这种问题。
不知道有多少同学还记得语句块 { ... } 在 Rust 中其实具有非常重要的作用(特别是相比其它大多数语言来说时):可以将变量声明在语句块内,当语句块结束时,变量会自动被 Drop,这个规则可以帮助我们解决很多借用冲突问题,特别是在 NLL 出来之前。
#![allow(unused)] fn main() { async fn foo() { { let x = NotSend::default(); } bar().await; } }
是不是很简单?最终我们还是通过 Drop 的方式解决了这个问题,当然,还是期待未来 std::mem::drop 也能派上用场。
递归使用 async fn
在内部实现中,async fn被编译成一个状态机,这会导致递归使用 async fn 变得较为复杂, 因为编译后的状态机还需要包含自身。
#![allow(unused)] fn main() { // foo函数: async fn foo() { step_one().await; step_two().await; } // 会被编译成类似下面的类型: enum Foo { First(StepOne), Second(StepTwo), } // 因此recursive函数 async fn recursive() { recursive().await; recursive().await; } // 会生成类似以下的类型 enum Recursive { First(Recursive), Second(Recursive), } }
这是典型的动态大小类型,它的大小会无限增长,因此编译器会直接报错:
error[E0733]: recursion in an `async fn` requires boxing
--> src/lib.rs:1:22
|
1 | async fn recursive() {
| ^ an `async fn` cannot invoke itself directly
|
= note: a recursive `async fn` must be rewritten to return a boxed future.
如果认真学过之前的章节,大家应该知道只要将其使用 Box 放到堆上而不是栈上,就可以解决,在这里还是要称赞下 Rust 的编译器,给出的提示总是这么精确recursion in an `async fn` requires boxing。
就算是使用 Box,这里也大有讲究。如果我们试图使用 Box::pin 这种方式去包裹是不行的,因为编译器自身的限制限制了我们(刚夸过它。。。)。为了解决这种问题,我们只能将 recursive 转变成一个正常的函数,该函数返回一个使用 Box 包裹的 async 语句块:
#![allow(unused)] fn main() { use futures::future::{BoxFuture, FutureExt}; fn recursive() -> BoxFuture<'static, ()> { async move { recursive().await; recursive().await; }.boxed() } }
在特征中使用 async
在目前版本中,我们还无法在特征中定义 async fn 函数,不过大家也不用担心,目前已经有计划在未来移除这个限制了。
#![allow(unused)] fn main() { trait Test { async fn test(); } }
运行后报错:
error[E0706]: functions in traits cannot be declared `async`
--> src/main.rs:5:5
|
5 | async fn test();
| -----^^^^^^^^^^^
| |
| `async` because of this
|
= note: `async` trait functions are not currently supported
= note: consider using the `async-trait` crate: https://crates.io/crates/async-trait
好在编译器给出了提示,让我们使用 async-trait 解决这个问题:
#![allow(unused)] fn main() { use async_trait::async_trait; #[async_trait] trait Advertisement { async fn run(&self); } struct Modal; #[async_trait] impl Advertisement for Modal { async fn run(&self) { self.render_fullscreen().await; for _ in 0..4u16 { remind_user_to_join_mailing_list().await; } self.hide_for_now().await; } } struct AutoplayingVideo { media_url: String, } #[async_trait] impl Advertisement for AutoplayingVideo { async fn run(&self) { let stream = connect(&self.media_url).await; stream.play().await; // 用视频说服用户加入我们的邮件列表 Modal.run().await; } } }
不过使用该包并不是免费的,每一次特征中的async函数被调用时,都会产生一次堆内存分配。对于大多数场景,这个性能开销都可以接受,但是当函数一秒调用几十万、几百万次时,就得小心这块儿代码的性能了!
一个实践项目: Web 服务器
知识学得再多,不实际应用也是纸上谈兵,不是忘掉就是废掉,对于技术学习尤为如此。在之前章节中,我们已经学习了 Async Rust 的方方面面,现在来将这些知识融会贯通,最终实现一个并发 Web 服务器。
多线程版本的 Web 服务器
在正式开始前,先来看一个单线程版本的 Web 服务器,该例子来源于 Rust Book 一书。
src/main.rs:
use std::fs; use std::io::prelude::*; use std::net::TcpListener; use std::net::TcpStream; fn main() { // 监听本地端口 7878 ,等待 TCP 连接的建立 let listener = TcpListener::bind("127.0.0.1:7878").unwrap(); // 阻塞等待请求的进入 for stream in listener.incoming() { let stream = stream.unwrap(); handle_connection(stream); } } fn handle_connection(mut stream: TcpStream) { // 从连接中顺序读取 1024 字节数据 let mut buffer = [0; 1024]; stream.read(&mut buffer).unwrap(); let get = b"GET / HTTP/1.1\r\n"; // 处理HTTP协议头,若不符合则返回404和对应的`html`文件 let (status_line, filename) = if buffer.starts_with(get) { ("HTTP/1.1 200 OK\r\n\r\n", "hello.html") } else { ("HTTP/1.1 404 NOT FOUND\r\n\r\n", "404.html") }; let contents = fs::read_to_string(filename).unwrap(); // 将回复内容写入连接缓存中 let response = format!("{status_line}{contents}"); stream.write_all(response.as_bytes()).unwrap(); // 使用flush将缓存中的内容发送到客户端 stream.flush().unwrap(); }
hello.html:
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>Hello!</title>
</head>
<body>
<h1>Hello!</h1>
<p>Hi from Rust</p>
</body>
</html>
404.html:
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>Hello!</title>
</head>
<body>
<h1>Oops!</h1>
<p>Sorry, I don't know what you're asking for.</p>
</body>
</html>
运行以上代码,并从浏览器访问 127.0.0.1:7878 你将看到一条来自 Ferris 的问候。
在回忆了单线程版本该如何实现后,我们也将进入正题,一起来实现一个基于 async 的异步 Web 服务器。
运行异步代码
一个 Web 服务器必须要能并发的处理大量来自用户的请求,也就是我们不能在处理完上一个用户的请求后,再处理下一个用户的请求。上面的单线程版本可以修改为多线程甚至于线程池来实现并发处理,但是线程还是太重了,使用 async 实现 Web 服务器才是最适合的。
首先将 handle_connection 修改为 async 实现:
#![allow(unused)] fn main() { async fn handle_connection(mut stream: TcpStream) { //<-- snip --> } }
该修改会将函数的返回值从 () 变成 Future<Output=()> ,因此直接运行将不再有任何效果,只用通过.await或执行器的poll调用后才能获取 Future 的结果。
在之前的代码中,我们使用了自己实现的简单的执行器来进行.await 或 poll ,实际上这只是为了学习原理,在实际项目中,需要选择一个三方的 async 运行时来实现相关的功能。 具体的选择我们将在下一章节进行讲解,现在先选择 async-std ,该包的最大优点就是跟标准库的 API 类似,相对来说更简单易用。
使用 async-std 作为异步运行时
下面的例子将演示如何使用一个异步运行时async-std来让之前的 async fn 函数运行起来,该运行时允许使用属性 #[async_std::main] 将我们的 fn main 函数变成 async fn main ,这样就可以在 main 函数中直接调用其它 async 函数,否则你得用之前章节的 block_on 方法来让 main 去阻塞等待异步函数的完成,但是这种简单粗暴的阻塞等待方式并不灵活。
修改 Cargo.toml 添加 async-std 包并开启相应的属性:
[dependencies]
futures = "0.3"
[dependencies.async-std]
version = "1.6"
features = ["attributes"]
下面将 main 函数修改为异步的,并在其中调用前面修改的异步版本 handle_connection :
#[async_std::main] async fn main() { let listener = TcpListener::bind("127.0.0.1:7878").unwrap(); for stream in listener.incoming() { let stream = stream.unwrap(); // 警告,这里无法并发 handle_connection(stream).await; } }
上面的代码虽然已经是异步的,实际上它还无法并发,原因我们后面会解释,先来模拟一下慢请求:
#![allow(unused)] fn main() { use async_std::task; async fn handle_connection(mut stream: TcpStream) { let mut buffer = [0; 1024]; stream.read(&mut buffer).unwrap(); let get = b"GET / HTTP/1.1\r\n"; let sleep = b"GET /sleep HTTP/1.1\r\n"; let (status_line, filename) = if buffer.starts_with(get) { ("HTTP/1.1 200 OK\r\n\r\n", "hello.html") } else if buffer.starts_with(sleep) { task::sleep(Duration::from_secs(5)).await; ("HTTP/1.1 200 OK\r\n\r\n", "hello.html") } else { ("HTTP/1.1 404 NOT FOUND\r\n\r\n", "404.html") }; let contents = fs::read_to_string(filename).unwrap(); let response = format!("{status_line}{contents}"); stream.write(response.as_bytes()).unwrap(); stream.flush().unwrap(); } }
上面是全新实现的 handle_connection ,它会在内部睡眠 5 秒,模拟一次用户慢请求,需要注意的是,我们并没有使用 std::thread::sleep 进行睡眠,原因是该函数是阻塞的,它会让当前线程陷入睡眠中,导致其它任务无法继续运行!因此我们需要一个睡眠函数 async_std::task::sleep,它仅会让当前的任务陷入睡眠,然后该任务会让出线程的控制权,这样线程就可以继续运行其它任务。
因此,光把函数变成 async 往往是不够的,还需要将它内部的代码也都变成异步兼容的,阻塞线程绝对是不可行的。
现在运行服务器,并访问 127.0.0.1:7878/sleep, 你会发现只有在完成第一个用户请求(5 秒后),才能开始处理第二个用户请求。现在再来看看该如何解决这个问题,让请求并发起来。
并发地处理连接
上面代码最大的问题是 listener.incoming() 是阻塞的迭代器。当 listener 在等待连接时,执行器是无法执行其它Future的,而且只有在我们处理完已有的连接后,才能接收新的连接。
解决方法是将 listener.incoming() 从一个阻塞的迭代器变成一个非阻塞的 Stream, 后者在前面章节有过专门介绍:
use async_std::net::TcpListener; use async_std::net::TcpStream; use futures::stream::StreamExt; #[async_std::main] async fn main() { let listener = TcpListener::bind("127.0.0.1:7878").await.unwrap(); listener .incoming() .for_each_concurrent(/* limit */ None, |tcpstream| async move { let tcpstream = tcpstream.unwrap(); handle_connection(tcpstream).await; }) .await; }
异步版本的 TcpListener 为 listener.incoming() 实现了 Stream 特征,以上修改有两个好处:
listener.incoming()不再阻塞- 使用
for_each_concurrent并发地处理从Stream获取的元素
现在上面的实现的关键在于 handle_connection 不能再阻塞:
#![allow(unused)] fn main() { use async_std::prelude::*; async fn handle_connection(mut stream: TcpStream) { let mut buffer = [0; 1024]; stream.read(&mut buffer).await.unwrap(); //<-- snip --> stream.write(response.as_bytes()).await.unwrap(); stream.flush().await.unwrap(); } }
在将数据读写改造成异步后,现在该函数也彻底变成了异步的版本,因此一次慢请求不再会阻止其它请求的运行。
使用多线程并行处理请求
聪明的读者不知道有没有发现,之前的例子有一个致命的缺陷:只能使用一个线程并发的处理用户请求。是的,这样也可以实现并发,一秒处理几千次请求问题不大,但是这毕竟没有利用上 CPU 的多核并行能力,无法实现性能最大化。
async 并发和多线程其实并不冲突,而 async-std 包也允许我们使用多个线程去处理,由于 handle_connection 实现了 Send 特征且不会阻塞,因此使用 async_std::task::spawn 是非常安全的:
use async_std::task::spawn; #[async_std::main] async fn main() { let listener = TcpListener::bind("127.0.0.1:7878").await.unwrap(); listener .incoming() .for_each_concurrent(/* limit */ None, |stream| async move { let stream = stream.unwrap(); spawn(handle_connection(stream)); }) .await; }
至此,我们实现了同时使用并行(多线程)和并发( async )来同时处理多个请求!
测试 handle_connection 函数
对于测试 Web 服务器,使用集成测试往往是最简单的,但是在本例子中,将使用单元测试来测试连接处理函数的正确性。
为了保证单元测试的隔离性和确定性,我们使用 MockTcpStream 来替代 TcpStream 。首先,修改 handle_connection 的函数签名让测试更简单,之所以可以修改签名,原因在于 async_std::net::TcpStream 实际上并不是必须的,只要任何结构体实现了 async_std::io::Read, async_std::io::Write 和 marker::Unpin 就可以替代它:
#![allow(unused)] fn main() { use std::marker::Unpin; use async_std::io::{Read, Write}; async fn handle_connection(mut stream: impl Read + Write + Unpin) { }
下面,来构建一个 mock 的 TcpStream 并实现了上面这些特征,它包含一些数据,这些数据将被拷贝到 read 缓存中, 然后返回 Poll::Ready 说明 read 已经结束:
#![allow(unused)] fn main() { use super::*; use futures::io::Error; use futures::task::{Context, Poll}; use std::cmp::min; use std::pin::Pin; struct MockTcpStream { read_data: Vec<u8>, write_data: Vec<u8>, } impl Read for MockTcpStream { fn poll_read( self: Pin<&mut Self>, _: &mut Context, buf: &mut [u8], ) -> Poll<Result<usize, Error>> { let size: usize = min(self.read_data.len(), buf.len()); buf[..size].copy_from_slice(&self.read_data[..size]); Poll::Ready(Ok(size)) } } }
Write的实现也类似,需要实现三个方法 : poll_write, poll_flush, 与 poll_close。 poll_write 会拷贝输入数据到 mock 的 TcpStream 中,当完成后返回 Poll::Ready。由于 TcpStream 无需 flush 和 close,因此另两个方法直接返回 Poll::Ready 即可。
#![allow(unused)] fn main() { impl Write for MockTcpStream { fn poll_write( mut self: Pin<&mut Self>, _: &mut Context, buf: &[u8], ) -> Poll<Result<usize, Error>> { self.write_data = Vec::from(buf); Poll::Ready(Ok(buf.len())) } fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Error>> { Poll::Ready(Ok(())) } fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Error>> { Poll::Ready(Ok(())) } } }
最后,我们的 mock 需要实现 Unpin 特征,表示它可以在内存中安全的移动,具体内容在前面章节有讲。
#![allow(unused)] fn main() { use std::marker::Unpin; impl Unpin for MockTcpStream {} }
现在可以准备开始测试了,在使用初始化数据设置好 MockTcpStream 后,我们可以使用 #[async_std::test] 来运行 handle_connection 函数,该函数跟 #[async_std::main] 的作用类似。为了确保 handle_connection 函数正确工作,需要根据初始化数据检查正确的数据被写入到 MockTcpStream 中。
#![allow(unused)] fn main() { use std::fs; #[async_std::test] async fn test_handle_connection() { let input_bytes = b"GET / HTTP/1.1\r\n"; let mut contents = vec![0u8; 1024]; contents[..input_bytes.len()].clone_from_slice(input_bytes); let mut stream = MockTcpStream { read_data: contents, write_data: Vec::new(), }; handle_connection(&mut stream).await; let mut buf = [0u8; 1024]; stream.read(&mut buf).await.unwrap(); let expected_contents = fs::read_to_string("hello.html").unwrap(); let expected_response = format!("HTTP/1.1 200 OK\r\n\r\n{}", expected_contents); assert!(stream.write_data.starts_with(expected_response.as_bytes())); } }
Tokio 使用指南
在上一个章节中,我们提到了 Rust 异步编程的限制,其中之一就是你必须引入社区提供的异步运行时,其中最有名的就是 tokio。
在本章中,我们一起来看看 tokio 到底有什么优势,以及该如何使用它。
本章在内容上大量借鉴和翻译了 tokio 官方文档Tokio Tutorial, 但是重新组织了内容形式并融入了很多自己的见解和感悟,给大家提供更好的可读性和知识扩展性
tokio 概览
对于 Async Rust,最最重要的莫过于底层的异步运行时,它提供了执行器、任务调度、异步 API 等核心服务。简单来说,使用 Rust 提供的 async/await 特性编写的异步代码要运行起来,就必须依赖于异步运行时,否则这些代码将毫无用处。
异步运行时
Rust 语言本身只提供了异步编程所需的基本特性,例如 async/await 关键字,标准库中的 Future 特征,官方提供的 futures 实用库,这些特性单独使用没有任何用处,因此我们需要一个运行时来将这些特性实现的代码运行起来。
异步运行时是由 Rust 社区提供的,它们的核心是一个 reactor 和一个或多个 executor(执行器):
reactor用于提供外部事件的订阅机制,例如I/O、进程间通信、定时器等executor在上一章我们有过深入介绍,它用于调度和执行相应的任务(Future)
目前最受欢迎的几个运行时有:
tokio,目前最受欢迎的异步运行时,功能强大,还提供了异步所需的各种工具(例如 tracing )、网络协议框架(例如 HTTP,gRPC )等等async-std,最大的优点就是跟标准库兼容性较强smol, 一个小巧的异步运行时
但是,大浪淘沙,留下的才是金子,随着时间的流逝,tokio越来越亮眼,无论是性能、功能还是社区、文档,它在各个方面都异常优秀,时至今日,可以说已成为事实上的标准。
异步运行时的兼容性
为何选择异步运行时这么重要?不仅仅是它们在功能、性能上存在区别,更重要的是当你选择了一个,往往就无法切换到另外一个,除非异步代码很少。
使用异步运行时,往往伴随着对它相关的生态系统的深入使用,因此耦合性会越来越强,直至最后你很难切换到另一个运行时,例如 tokio 和 async-std ,就存在这种问题。
如果你实在有这种需求,可以考虑使用 async-compat,该包提供了一个中间层,用于兼容 tokio 和其它运行时。
结论
相信大家看到现在,心中应该有一个结论了。首先,运行时之间的不兼容性,让我们必须提前选择一个运行时,并且在未来坚持用下去,那这个运行时就应该是最优秀、最成熟的那个,tokio 几乎成了不二选择,当然 tokio 也有自己的问题:更难上手和运行时之间的兼容性。
如果你只用 tokio ,那兼容性自然不是问题,至于难以上手,Rust 这么难,我们都学到现在了,何况区区一个异步运行时,在本书的帮助下,这些都不再是问题:)
tokio 简介
tokio 是一个纸醉金迷之地,只要有钱就可以为所欲为,哦,抱歉,走错片场了。tokio 是 Rust 最优秀的异步运行时框架,它提供了写异步网络服务所需的几乎所有功能,不仅仅适用于大型服务器,还适用于小型嵌入式设备,它主要由以下组件构成:
- 多线程版本的异步运行时,可以运行使用
async/await编写的代码 - 标准库中阻塞 API 的异步版本,例如
thread::sleep会阻塞当前线程,tokio中就提供了相应的异步实现版本 - 构建异步编程所需的生态,甚至还提供了
tracing用于日志和分布式追踪, 提供console用于 Debug 异步编程
优势
下面一起来看看使用 tokio 能给你提供哪些优势。
高性能
因为快所以快,前者是 Rust 快,后者是 tokio 快。 tokio 在编写时充分利用了 Rust 提供的各种零成本抽象和高性能特性,而且贯彻了 Rust 的牛逼思想:如果你选择手写代码,那么最好的结果就是跟 tokio 一样快!
以下是一张官方提供的性能参考图,大致能体现出 tokio 的性能之恐怖:

高可靠
Rust 语言的安全可靠性顺理成章的影响了 tokio 的可靠性,曾经有一个调查给出了令人乍舌的结论:软件系统 70%的高危漏洞都是由内存不安全性导致的。
在 Rust 提供的安全性之外,tokio 还致力于提供一致性的行为表现:无论你何时运行系统,它的预期表现和性能都是一致的,例如不会出现莫名其妙的请求延迟或响应时间大幅增加。
简单易用
通过 Rust 提供的 async/await 特性,编写异步程序的复杂性相比当初已经大幅降低,同时 tokio 还为我们提供了丰富的生态,进一步大幅降低了其复杂性。
同时 tokio 遵循了标准库的命名规则,让熟悉标准库的用户可以很快习惯于 tokio 的语法,再借助于 Rust 强大的类型系统,用户可以轻松地编写和交付正确的代码。
使用灵活性
tokio 支持你灵活的定制自己想要的运行时,例如你可以选择多线程 + 任务盗取模式的复杂运行时,也可以选择单线程的轻量级运行时。总之,几乎你的每一种需求在 tokio 中都能寻找到支持(画外音:强大的灵活性需要一定的复杂性来换取,并不是免费的午餐)。
劣势
虽然 tokio 对于大多数需要并发的项目都是非常适合的,但是确实有一些场景它并不适合使用:
- 并行运行 CPU 密集型的任务。
tokio非常适合于 IO 密集型任务,这些 IO 任务的绝大多数时间都用于阻塞等待 IO 的结果,而不是刷刷刷的单烤 CPU。如果你的应用是 CPU 密集型(例如并行计算),建议使用rayon,当然,对于其中的 IO 任务部分,你依然可以混用tokio - 读取大量的文件。读取文件的瓶颈主要在于操作系统,因为 OS 没有提供异步文件读取接口,大量的并发并不会提升文件读取的并行性能,反而可能会造成不可忽视的性能损耗,因此建议使用线程(或线程池)的方式
- 发送少量 HTTP 请求。
tokio的优势是给予你并发处理大量任务的能力,对于这种轻量级 HTTP 请求场景,tokio除了增加你的代码复杂性,并无法带来什么额外的优势。因此,对于这种场景,你可以使用reqwest库,它会更加简单易用。
若大家使用 tokio,那 CPU 密集的任务尤其需要用线程的方式去处理,例如使用
spawn_blocking创建一个阻塞的线程去完成相应 CPU 密集任务。原因是:tokio 是协作式的调度器,如果某个 CPU 密集的异步任务是通过 tokio 创建的,那理论上来说,该异步任务需要跟其它的异步任务交错执行,最终大家都得到了执行,皆大欢喜。但实际情况是,CPU 密集的任务很可能会一直霸着着 CPU,此时 tokio 的调度方式决定了该任务会一直被执行,这意味着,其它的异步任务无法得到执行的机会,最终这些任务都会因为得不到资源而饿死。
而使用
spawn_blocking后,会创建一个单独的 OS 线程,该线程并不会被 tokio 所调度( 被 OS 所调度 ),因此它所执行的 CPU 密集任务也不会导致 tokio 调度的那些异步任务被饿死
总结
离开三方开源社区提供的异步运行时, async/await 什么都不是,甚至还不如一堆破铜烂铁,除非你选择根据自己的需求手撸一个。
而 tokio 就是那颗皇冠上的夜明珠,也是值得我们投入时间去深入学习的开源库,它的设计原理和代码实现都异常优秀,在之后的章节中,我们将对其进行深入学习和剖析,敬请期待。
tokio 初印象
又到了喜闻乐见的初印象环节,这个环节决定了你心中的那 24 盏灯最终是全亮还是全灭。
在本文中,我们将看看本专题的学习目标、tokio该怎么引入以及如何实现一个 Hello Tokio 项目,最终亮灯还是灭灯的决定权留给各位看官。但我提前说好,如果你全灭了,但却找不到更好的,未来还是得回来真香 :P
专题目标
通过 API 学项目无疑是无聊的,因此我们采用一个与众不同的方式:边学边练,在本专题的最后你将拥有一个 redis 客户端和服务端,当然不会实现一个完整版本的 redis ,只会提供基本的功能和部分常用的命令。
mini-redis
redis 的项目源码可以在这里访问,本项目是从官方地址 fork 而来,在未来会提供注释和文档汉化。
再次声明:该项目仅仅用于学习目的,因此它的文档注释非常全,但是它完全无法作为 redis 的替代品。
环境配置
首先,我们假定你已经安装了 Rust 和相关的工具链,例如 cargo。其中 Rust 版本的最低要求是 1.45.0,建议使用最新版 1.58:
sunfei@sunface $ rustc --version
rustc 1.58.0 (02072b482 2022-01-11)
接下来,安装 mini-redis 的服务器端,它可以用来测试我们后面将要实现的 redis 客户端:
$ cargo install mini-redis
如果下载失败,也可以通过这个地址下载源码,然后在本地通过
cargo run运行。
下载成功后,启动服务端:
$ mini-redis-server
然后,再使用客户端测试下刚启动的服务端:
$ mini-redis-cli set foo 1
OK
$ mini-redis-cli get foo
"1"
不得不说,还挺好用的,先自我陶醉下 :) 此时,万事俱备,只欠东风,接下来是时候亮"箭"了:实现我们的 Hello Tokio 项目。
Hello Tokio
与简单无比的 Hello World 有所不同(简单?还记得本书开头时,湖畔边的那个多国语言版本的你好,世界嘛~~),Hello Tokio 它承载着"非常艰巨"的任务,那就是向刚启动的 redis 服务器写入一个 key=hello, value=world ,然后再读取出来,嗯,使用 mini-redis 客户端 :)
分析未到,代码先行
在详细讲解之前,我们先来看看完整的代码,让大家有一个直观的印象。首先,创建一个新的 Rust 项目:
$ cargo new my-redis
$ cd my-redis
然后在 Cargo.toml 中添加相关的依赖:
[dependencies]
tokio = { version = "1", features = ["full"] }
mini-redis = "0.4"
接下来,使用以下代码替换 main.rs 中的内容:
use mini_redis::{client, Result}; #[tokio::main] async fn main() -> Result<()> { // 建立与mini-redis服务器的连接 let mut client = client::connect("127.0.0.1:6379").await?; // 设置 key: "hello" 和 值: "world" client.set("hello", "world".into()).await?; // 获取"key=hello"的值 let result = client.get("hello").await?; println!("从服务器端获取到结果={:?}", result); Ok(()) }
不知道你之前启动的 mini-redis-server 关闭没有,如果关了,记得重新启动下,否则我们的代码就是意大利空气炮。
最后,运行这个项目:
$ cargo run
从服务器端获取到结果=Some("world")
Perfect, 代码成功运行,是时候来解释下其中蕴藏的至高奥秘了。
原理解释
代码篇幅虽然不长,但是还是有不少值得关注的地方,接下来我们一起来看看。
#![allow(unused)] fn main() { let mut client = client::connect("127.0.0.1:6379").await?; }
client::connect 函数由mini-redis 包提供,它使用异步的方式跟指定的远程 IP 地址建立 TCP 长连接,一旦连接建立成功,那 client 的赋值初始化也将完成。
特别值得注意的是:虽然该连接是异步建立的,但是从代码本身来看,完全是同步的代码编写方式,唯一能说明异步的点就是 .await。
什么是异步编程
大部分计算机程序都是按照代码编写的顺序来执行的:先执行第一行,然后第二行,以此类推(当然,还要考虑流程控制,例如循环)。当进行同步编程时,一旦程序遇到一个操作无法被立即完成,它就会进入阻塞状态,直到该操作完成为止。
因此同步编程非常符合我们人类的思维习惯,是一个顺其自然的过程,被几乎每一个程序员所喜欢(本来想说所有,但我不敢打包票,毕竟总有特立独行之士)。例如,当建立 TCP 连接时,当前线程会被阻塞,直到等待该连接建立完成,然后才往下继续进行。
而使用异步编程,无法立即完成的操作会被切到后台去等待,因此当前线程不会被阻塞,它会接着执行其它的操作。一旦之前的操作准备好可以继续执行后,它会通知执行器,然后执行器会调度它并从上次离开的点继续执行。但是大家想象下,如果没有使用 await,而是按照这个异步的流程使用通知 -> 回调的方式实现,代码该多么的难写和难读!
好在 Rust 为我们提供了 async/await 的异步编程特性,让我们可以像写同步代码那样去写异步的代码,也让这个世界美好依旧。
编译时绿色线程
一个函数可以通过async fn的方式被标记为异步函数:
#![allow(unused)] fn main() { use mini_redis::Result; use mini_redis::client::Client; use tokio::net::ToSocketAddrs; pub async fn connect<T: ToSocketAddrs>(addr: T) -> Result<Client> { // ... } }
在上例中,redis 的连接函数 connect 实现如上,它看上去很像是一个同步函数,但是 async fn 出卖了它。
async fn 异步函数并不会直接返回值,而是返回一个 Future,顾名思义,该 Future 会在未来某个时间点被执行,然后最终获取到真实的返回值 Result<Client>。
async/await 的原理就算大家不理解,也不妨碍使用
tokio写出能用的服务,但是如果想要更深入的用好,强烈建议认真读下本书的async/await异步编程章节,你会对 Rust 的异步编程有一个全新且深刻的认识。
由于 async 会返回一个 Future,因此我们还需要配合使用 .await 来让该 Future 运行起来,最终获得返回值:
async fn say_to_world() -> String { String::from("world") } #[tokio::main] async fn main() { // 此处的函数调用是惰性的,并不会执行 `say_to_world()` 函数体中的代码 let op = say_to_world(); // 首先打印出 "hello" println!("hello"); // 使用 `.await` 让 `say_to_world` 开始运行起来 println!("{}", op.await); }
上面代码输出如下:
hello
world
而大家可能很好奇 async fn 到底返回什么吧?它实际上返回的是一个实现了 Future 特征的匿名类型: impl Future<Output = String>。
async main
在代码中,使用了一个与众不同的 main 函数 : async fn main ,而且是用 #[tokio::main] 属性进行了标记。异步 main 函数有以下意义:
.await只能在async函数中使用,如果是以前的fn main,那它内部是无法直接使用async函数的!这个会极大的限制了我们的使用场景- 异步运行时本身需要初始化
因此 #[tokio::main] 宏在将 async fn main 隐式的转换为 fn main 的同时还对整个异步运行时进行了初始化。例如以下代码:
#[tokio::main] async fn main() { println!("hello"); }
将被转换成:
fn main() { let mut rt = tokio::runtime::Runtime::new().unwrap(); rt.block_on(async { println!("hello"); }) }
最终,Rust 编译器就愉快地执行这段代码了。
cargo feature
在引入 tokio 包时,我们在 Cargo.toml 文件中添加了这么一行:
tokio = { version = "1", features = ["full"] }
里面有个 features = ["full"] 可能大家会比较迷惑,当然,关于它的具体解释在本书的 Cargo 详解专题 有介绍,这里就简单进行说明。
Tokio 有很多功能和特性,例如 TCP,UDP,Unix sockets,同步工具,多调度类型等等,不是每个应用都需要所有的这些特性。为了优化编译时间和最终生成可执行文件大小、内存占用大小,应用可以对这些特性进行可选引入。
而这里为了演示的方便,我们使用 full ,表示直接引入所有的特性。
总结
大家对 tokio 的初印象如何?可否 24 灯全亮通过?
总之,tokio 做的事情其实是细雨润无声的,在大多数时候,我们并不能感觉到它的存在,但是它确实是异步编程中最重要的一环(或者之一),深入了解它对我们的未来之路会有莫大的帮助。
接下来,正式开始 tokio 的学习之旅。
创建异步任务
同志们,抓稳了,我们即将换挡提速,通向 mini-redis 服务端的高速之路已经开启。
不过在开始之前,先来做点收尾工作:上一章节中,我们实现了一个简易的 mini-redis 客户端并支持了 SET/GET 操作, 现在将该代码移动到 examples 文件夹下,因为我们这个章节要实现的是服务器,后面可以通过运行 example 的方式,用之前客户端示例对我们的服务器端进行测试:
$ mkdir -p examples
$ mv src/main.rs examples/hello-redis.rs
并在 Cargo.toml 里添加 [[example]] 说明。关于 example 的详细说明,可以在Cargo使用指南里进一步了解。
[[example]]
name = "hello-redis"
path = "examples/hello-redis.rs"
然后再重新创建一个空的 src/main.rs 文件,至此替换文档已经完成,提速正式开始。
接收 sockets
作为服务器端,最基础的工作无疑是接收外部进来的 TCP 连接,可以通过 tokio::net::TcpListener 来完成。
Tokio 中大多数类型的名称都和标准库中对应的同步类型名称相同,而且,如果没有特殊原因,Tokio 的 API 名称也和标准库保持一致,只不过用
async fn取代fn来声明函数。
TcpListener 监听 6379 端口,然后通过循环来接收外部进来的连接,每个连接在处理完后会被关闭。对于目前来说,我们的任务很简单:读取命令、打印到标准输出 stdout,最后回复给客户端一个错误。
use tokio::net::{TcpListener, TcpStream}; use mini_redis::{Connection, Frame}; #[tokio::main] async fn main() { // Bind the listener to the address // 监听指定地址,等待 TCP 连接进来 let listener = TcpListener::bind("127.0.0.1:6379").await.unwrap(); loop { // 第二个被忽略的项中包含有新连接的 `IP` 和端口信息 let (socket, _) = listener.accept().await.unwrap(); process(socket).await; } } async fn process(socket: TcpStream) { // `Connection` 对于 redis 的读写进行了抽象封装,因此我们读到的是一个一个数据帧frame(数据帧 = redis命令 + 数据),而不是字节流 // `Connection` 是在 mini-redis 中定义 let mut connection = Connection::new(socket); if let Some(frame) = connection.read_frame().await.unwrap() { println!("GOT: {:?}", frame); // 回复一个错误 let response = Frame::Error("unimplemented".to_string()); connection.write_frame(&response).await.unwrap(); } }
现在运行我们的简单服务器 :
cargo run
此时服务器会处于循环等待以接收连接的状态,接下来在一个新的终端窗口中启动上一章节中的 redis 客户端,由于相关代码已经放入 examples 文件夹下,因此我们可以使用 --example 来指定运行该客户端示例:
$ cargo run --example hello-redis
此时,客户端的输出是: Error: "unimplemented", 同时服务器端打印出了客户端发来的由 redis 命令和数据 组成的数据帧: GOT: Array([Bulk(b"set"), Bulk(b"hello"), Bulk(b"world")])。
生成任务
上面的服务器,如果你仔细看,它其实一次只能接受和处理一条 TCP 连接,只有等当前的处理完并结束后,才能开始接收下一条连接。原因在于 loop 循环中的 await 会导致当前任务进入阻塞等待,也就是 loop 循环会被阻塞。
而这显然不是我们想要的,服务器能并发地处理多条连接的请求,才是正确的打开姿势,下面来看看如何实现真正的并发。
关于并发和并行,在多线程章节中有详细的解释
为了并发的处理连接,需要为每一条进来的连接都生成( spawn )一个新的任务, 然后在该任务中处理连接:
use tokio::net::TcpListener; #[tokio::main] async fn main() { let listener = TcpListener::bind("127.0.0.1:6379").await.unwrap(); loop { let (socket, _) = listener.accept().await.unwrap(); // 为每一条连接都生成一个新的任务, // `socket` 的所有权将被移动到新的任务中,并在那里进行处理 tokio::spawn(async move { process(socket).await; }); } }
任务
一个 Tokio 任务是一个异步的绿色线程,它们通过 tokio::spawn 进行创建,该函数会返回一个 JoinHandle 类型的句柄,调用者可以使用该句柄跟创建的任务进行交互。
spawn 函数的参数是一个 async 语句块,该语句块甚至可以返回一个值,然后调用者可以通过 JoinHandle 句柄获取该值:
#[tokio::main] async fn main() { let handle = tokio::spawn(async { 10086 }); let out = handle.await.unwrap(); println!("GOT {}", out); }
以上代码会打印出GOT 10086。实际上,上面代码中.await 会返回一个 Result ,若 spawn 创建的任务正常运行结束,则返回一个 Ok(T)的值,否则会返回一个错误 Err:例如任务内部发生了 panic 或任务因为运行时关闭被强制取消时。
任务是调度器管理的执行单元。spawn生成的任务会首先提交给调度器,然后由它负责调度执行。需要注意的是,执行任务的线程未必是创建任务的线程,任务完全有可能运行在另一个不同的线程上,而且任务在生成后,它还可能会在线程间被移动。
任务在 Tokio 中远比看上去要更轻量,例如创建一个任务仅仅需要一次 64 字节大小的内存分配。因此应用程序在生成任务上,完全不应该有任何心理负担,除非你在一台没那么好的机器上疯狂生成了几百万个任务。。。
'static 约束
当使用 Tokio 创建一个任务时,该任务类型的生命周期必须是 'static。意味着,在任务中不能使用外部数据的引用:
use tokio::task; #[tokio::main] async fn main() { let v = vec![1, 2, 3]; task::spawn(async { println!("Here's a vec: {:?}", v); }); }
上面代码中,spawn 出的任务引用了外部环境中的变量 v ,导致以下报错:
error[E0373]: async block may outlive the current function, but
it borrows `v`, which is owned by the current function
--> src/main.rs:7:23
|
7 | task::spawn(async {
| _______________________^
8 | | println!("Here's a vec: {:?}", v);
| | - `v` is borrowed here
9 | | });
| |_____^ may outlive borrowed value `v`
|
note: function requires argument type to outlive `'static`
--> src/main.rs:7:17
|
7 | task::spawn(async {
| _________________^
8 | | println!("Here's a vector: {:?}", v);
9 | | });
| |_____^
help: to force the async block to take ownership of `v` (and any other
referenced variables), use the `move` keyword
|
7 | task::spawn(async move {
8 | println!("Here's a vec: {:?}", v);
9 | });
|
原因在于:默认情况下,变量并不是通过 move 的方式转移进 async 语句块的, v 变量的所有权依然属于 main 函数,因为任务内部的 println! 是通过借用的方式使用了 v,但是这种借用并不能满足 'static 生命周期的要求。
在报错的同时,Rust 编译器还给出了相当有帮助的提示:为 async 语句块使用 move 关键字,这样就能将 v 的所有权从 main 函数转移到新创建的任务中。
但是 move 有一个问题,一个数据只能被一个任务使用,如果想要多个任务使用一个数据,就有些强人所难。不知道还有多少同学记得 Arc,它可以轻松解决该问题,还是线程安全的。
在上面的报错中,还有一句很奇怪的信息function requires argument type to outlive `'static` , 函数要求参数类型的生命周期必须比 'static 长,问题是 'static 已经活得跟整个程序一样久了,难道函数的参数还能活得更久?大家可能会觉得编译器秀逗了,毕竟其它语言编译器也有秀逗的时候:)
先别急着给它扣帽子,虽然我有时候也想这么做。。原因是它说的是类型必须活得比 'static 长,而不是值。当我们说一个值是 'static 时,意味着它将永远存活。这个很重要,因为编译器无法知道新创建的任务将存活多久,所以唯一的办法就是让任务永远存活。
如果大家对于 '&static 和 T: 'static 较为模糊,强烈建议回顾下该章节。
Send 约束
tokio::spawn 生成的任务必须实现 Send 特征,因为当这些任务在 .await 执行过程中发生阻塞时,Tokio 调度器会将任务在线程间移动。
一个任务要实现 Send 特征,那它在 .await 调用的过程中所持有的全部数据都必须实现 Send 特征。当 .await 调用发生阻塞时,任务会让出当前线程所有权给调度器,然后当任务准备好后,调度器会从上一次暂停的位置继续执行该任务。该流程能正确的工作,任务必须将.await之后使用的所有状态保存起来,这样才能在中断后恢复现场并继续执行。若这些状态实现了 Send 特征(可以在线程间安全地移动),那任务自然也就可以在线程间安全地移动。
例如以下代码可以工作:
use tokio::task::yield_now; use std::rc::Rc; #[tokio::main] async fn main() { tokio::spawn(async { // 语句块的使用强制了 `rc` 会在 `.await` 被调用前就被释放, // 因此 `rc` 并不会影响 `.await`的安全性 { let rc = Rc::new("hello"); println!("{}", rc); } // `rc` 的作用范围已经失效,因此当任务让出所有权给当前线程时,它无需作为状态被保存起来 yield_now().await; }); }
但是下面代码就不行:
use tokio::task::yield_now; use std::rc::Rc; #[tokio::main] async fn main() { tokio::spawn(async { let rc = Rc::new("hello"); // `rc` 在 `.await` 后还被继续使用,因此它必须被作为任务的状态保存起来 yield_now().await; // 事实上,注释掉下面一行代码,依然会报错 // 原因是:是否保存,不取决于 `rc` 是否被使用,而是取决于 `.await`在调用时是否仍然处于 `rc` 的作用域中 println!("{}", rc); // rc 作用域在这里结束 }); }
这里有一个很重要的点,代码注释里有讲到,但是我们再重复一次: rc 是否会保存到任务状态中,取决于 .await 的调用是否处于它的作用域中,上面代码中,就算你注释掉 println! 函数,该报错依然会报错,因为 rc 的作用域直到 async 的末尾才结束!
下面是相应的报错,在下一章节,我们还会继续深入讨论该错误:
error: future cannot be sent between threads safely
--> src/main.rs:6:5
|
6 | tokio::spawn(async {
| ^^^^^^^^^^^^ future created by async block is not `Send`
|
::: [..]spawn.rs:127:21
|
127 | T: Future + Send + 'static,
| ---- required by this bound in
| `tokio::task::spawn::spawn`
|
= help: within `impl std::future::Future`, the trait
| `std::marker::Send` is not implemented for
| `std::rc::Rc<&str>`
note: future is not `Send` as this value is used across an await
--> src/main.rs:10:9
|
7 | let rc = Rc::new("hello");
| -- has type `std::rc::Rc<&str>` which is not `Send`
...
10 | yield_now().await;
| ^^^^^^^^^^^^^^^^^ await occurs here, with `rc` maybe
| used later
11 | println!("{}", rc);
12 | });
| - `rc` is later dropped here
使用 HashMap 存储数据
现在,我们可以继续前进了,下面来实现 process 函数,它用于处理进入的命令。相应的值将被存储在 HashMap 中: 通过 SET 命令存值,通过 GET 命令来取值。
同时,我们将使用循环的方式在同一个客户端连接中处理多次连续的请求:
#![allow(unused)] fn main() { use tokio::net::TcpStream; use mini_redis::{Connection, Frame}; async fn process(socket: TcpStream) { use mini_redis::Command::{self, Get, Set}; use std::collections::HashMap; // 使用 hashmap 来存储 redis 的数据 let mut db = HashMap::new(); // `mini-redis` 提供的便利函数,使用返回的 `connection` 可以用于从 socket 中读取数据并解析为数据帧 let mut connection = Connection::new(socket); // 使用 `read_frame` 方法从连接获取一个数据帧:一条redis命令 + 相应的数据 while let Some(frame) = connection.read_frame().await.unwrap() { let response = match Command::from_frame(frame).unwrap() { Set(cmd) => { // 值被存储为 `Vec<u8>` 的形式 db.insert(cmd.key().to_string(), cmd.value().to_vec()); Frame::Simple("OK".to_string()) } Get(cmd) => { if let Some(value) = db.get(cmd.key()) { // `Frame::Bulk` 期待数据的类型是 `Bytes`, 该类型会在后面章节讲解, // 此时,你只要知道 `&Vec<u8>` 可以使用 `into()` 方法转换成 `Bytes` 类型 Frame::Bulk(value.clone().into()) } else { Frame::Null } } cmd => panic!("unimplemented {:?}", cmd), }; // 将请求响应返回给客户端 connection.write_frame(&response).await.unwrap(); } } // main 函数在之前已实现 }
使用 cargo run 运行服务器,然后再打开另一个终端窗口,运行 hello-redis 客户端示例: cargo run --example hello-redis。
Bingo,在看了这么多原理后,我们终于迈出了小小的第一步,并获取到了存在 HashMap 中的值: 从服务器端获取到结果=Some(b"world")。
但是问题又来了:这些值无法在 TCP 连接中共享,如果另外一个用户连接上来并试图同时获取 hello 这个 key,他将一无所获。
共享状态
上一章节中,咱们搭建了一个异步的 redis 服务器,并成功的提供了服务,但是其隐藏了一个巨大的问题:状态(数据)无法在多个连接之间共享,下面一起来看看该如何解决。
解决方法
好在 Tokio 十分强大,上面问题对应的解决方法也不止一种:
- 使用
Mutex来保护数据的共享访问 - 生成一个异步任务去管理状态,然后各个连接使用消息传递的方式与其进行交互
其中,第一种方法适合比较简单的数据,而第二种方法适用于需要异步工作的,例如 I/O 原语。由于我们使用的数据存储类型是 HashMap,使用到的相关操作是 insert 和 get ,又因为这两个操作都不是异步的,因此只要使用 Mutex 即可解决问题。
在上面的描述中,说实话第二种方法及其适用的场景并不是很好理解,但没关系,在后面章节会进行详细介绍。
添加 bytes 依赖包
在上一节中,我们使用 Vec<u8> 来保存目标数据,但是它有一个问题,对它进行克隆时会将底层数据也整个复制一份,效率很低,但是克隆操作对于我们在多连接间共享数据又是必不可少的。
因此这里咱们新引入一个 bytes 包,它包含一个 Bytes 类型,当对该类型的值进行克隆时,就不再会克隆底层数据。事实上,Bytes 是一个引用计数类型,跟 Arc 非常类似,或者准确的说,Bytes 就是基于 Arc 实现的,但相比后者Bytes 提供了一些额外的能力。
在 Cargo.toml 的 [dependencies] 中引入 bytes :
bytes = "1"
初始化 HashMap
由于 HashMap 会在多个任务甚至多个线程间共享,再结合之前的选择,最终我们决定使用 Arc<Mutex<T>> 的方式对其进行包裹。
但是,大家先来畅想一下使用它进行包裹后的类型长什么样? 大概,可能,长这样:Arc<Mutex<HashMap<String, Bytes>>>,天哪噜,一不小心,你就遇到了 Rust 的阴暗面:类型大串烧。可以想象,如果要在代码中到处使用这样的类型,可读性会极速下降,因此我们需要一个类型别名( type alias )来简化下:
#![allow(unused)] fn main() { use bytes::Bytes; use std::collections::HashMap; use std::sync::{Arc, Mutex}; type Db = Arc<Mutex<HashMap<String, Bytes>>>; }
此时,Db 就是一个类型别名,使用它就可以替代那一大串的东东,等下你就能看到功效。
接着,我们需要在 main 函数中对 HashMap 进行初始化,然后使用 Arc 克隆一份它的所有权并将其传入到生成的异步任务中。事实上在 Tokio 中,这里的 Arc 被称为 handle,或者更宽泛的说,handle 在 Tokio 中可以用来访问某个共享状态。
use tokio::net::TcpListener; use std::collections::HashMap; use std::sync::{Arc, Mutex}; #[tokio::main] async fn main() { let listener = TcpListener::bind("127.0.0.1:6379").await.unwrap(); println!("Listening"); let db = Arc::new(Mutex::new(HashMap::new())); loop { let (socket, _) = listener.accept().await.unwrap(); // 将 handle 克隆一份 let db = db.clone(); println!("Accepted"); tokio::spawn(async move { process(socket, db).await; }); } }
为何使用 std::sync::Mutex
上面代码还有一点非常重要,那就是我们使用了 std::sync::Mutex 来保护 HashMap,而不是使用 tokio::sync::Mutex。
在使用 Tokio 编写异步代码时,一个常见的错误无条件地使用 tokio::sync::Mutex ,而真相是:Tokio 提供的异步锁只应该在跨多个 .await调用时使用,而且 Tokio 的 Mutex 实际上内部使用的也是 std::sync::Mutex。
多补充几句,在异步代码中,关于锁的使用有以下经验之谈:
- 锁如果在多个
.await过程中持有,应该使用 Tokio 提供的锁,原因是.await的过程中锁可能在线程间转移,若使用标准库的同步锁存在死锁的可能性,例如某个任务刚获取完锁,还没使用完就因为.await让出了当前线程的所有权,结果下个任务又去获取了锁,造成死锁 - 锁竞争不多的情况下,使用
std::sync::Mutex - 锁竞争多,可以考虑使用三方库提供的性能更高的锁,例如
parking_lot::Mutex
更新 process()
process() 函数不再初始化 HashMap,取而代之的是它使用了 HashMap 的一个 handle 作为参数:
#![allow(unused)] fn main() { use tokio::net::TcpStream; use mini_redis::{Connection, Frame}; async fn process(socket: TcpStream, db: Db) { use mini_redis::Command::{self, Get, Set}; let mut connection = Connection::new(socket); while let Some(frame) = connection.read_frame().await.unwrap() { let response = match Command::from_frame(frame).unwrap() { Set(cmd) => { let mut db = db.lock().unwrap(); db.insert(cmd.key().to_string(), cmd.value().clone()); Frame::Simple("OK".to_string()) } Get(cmd) => { let db = db.lock().unwrap(); if let Some(value) = db.get(cmd.key()) { Frame::Bulk(value.clone()) } else { Frame::Null } } cmd => panic!("unimplemented {:?}", cmd), }; connection.write_frame(&response).await.unwrap(); } } }
任务、线程和锁竞争
当竞争不多的时候,使用阻塞性的锁去保护共享数据是一个正确的选择。当一个锁竞争触发后,当前正在执行任务(请求锁)的线程会被阻塞,并等待锁被前一个使用者释放。这里的关键就是:锁竞争不仅仅会导致当前的任务被阻塞,还会导致执行任务的线程被阻塞,因此该线程准备执行的其它任务也会因此被阻塞!
默认情况下,Tokio 调度器使用了多线程模式,此时如果有大量的任务都需要访问同一个锁,那么锁竞争将变得激烈起来。当然,你也可以使用 current_thread 运行时设置,在该设置下会使用一个单线程的调度器(执行器),所有的任务都会创建并执行在当前线程上,因此不再会有锁竞争。
current_thread 是一个轻量级、单线程的运行时,当任务数不多或连接数不多时是一个很好的选择。例如你想在一个异步客户端库的基础上提供给用户同步的 API 访问时,该模式就很适用
当同步锁的竞争变成一个问题时,使用 Tokio 提供的异步锁几乎并不能帮你解决问题,此时可以考虑如下选项:
- 创建专门的任务并使用消息传递的方式来管理状态
- 将锁进行分片
- 重构代码以避免锁
在我们的例子中,由于每一个 key 都是独立的,因此对锁进行分片将成为一个不错的选择:
#![allow(unused)] fn main() { type ShardedDb = Arc<Vec<Mutex<HashMap<String, Vec<u8>>>>>; fn new_sharded_db(num_shards: usize) -> ShardedDb { let mut db = Vec::with_capacity(num_shards); for _ in 0..num_shards { db.push(Mutex::new(HashMap::new())); } Arc::new(db) } }
在这里,我们创建了 N 个不同的存储实例,每个实例都会存储不同的分片数据,例如我们有a-i共 9 个不同的 key, 可以将存储分成 3 个实例,那么第一个实例可以存储 a-c,第二个d-f,以此类推。在这种情况下,访问 b 时,只需要锁住第一个实例,此时二、三实例依然可以正常访问,因此锁被成功的分片了。
在分片后,使用给定的 key 找到对应的值就变成了两个步骤:首先,使用 key 通过特定的算法寻找到对应的分片,然后再使用该 key 从分片中查询到值:
#![allow(unused)] fn main() { let shard = db[hash(key) % db.len()].lock().unwrap(); shard.insert(key, value); }
这里我们使用 hash 算法来进行分片,但是该算法有个缺陷:分片的数量不能变,一旦变了后,那之前落入分片 1 的key很可能将落入到其它分片中,最终全部乱掉。此时你可以考虑dashmap,它提供了更复杂、更精妙的支持分片的hash map。
在 .await 期间持有锁
在某些时候,你可能会不经意写下这种代码:
#![allow(unused)] fn main() { use std::sync::{Mutex, MutexGuard}; async fn increment_and_do_stuff(mutex: &Mutex<i32>) { let mut lock: MutexGuard<i32> = mutex.lock().unwrap(); *lock += 1; do_something_async().await; } // 锁在这里超出作用域 }
如果你要 spawn 一个任务来执行上面的函数的话,会报错:
error: future cannot be sent between threads safely
--> src/lib.rs:13:5
|
13 | tokio::spawn(async move {
| ^^^^^^^^^^^^ future created by async block is not `Send`
|
::: /playground/.cargo/registry/src/github.com-1ecc6299db9ec823/tokio-0.2.21/src/task/spawn.rs:127:21
|
127 | T: Future + Send + 'static,
| ---- required by this bound in `tokio::task::spawn::spawn`
|
= help: within `impl std::future::Future`, the trait `std::marker::Send` is not implemented for `std::sync::MutexGuard<'_, i32>`
note: future is not `Send` as this value is used across an await
--> src/lib.rs:7:5
|
4 | let mut lock: MutexGuard<i32> = mutex.lock().unwrap();
| -------- has type `std::sync::MutexGuard<'_, i32>` which is not `Send`
...
7 | do_something_async().await;
| ^^^^^^^^^^^^^^^^^^^^^^^^^^ await occurs here, with `mut lock` maybe used later
8 | }
| - `mut lock` is later dropped here
错误的原因在于 std::sync::MutexGuard 类型并没有实现 Send 特征,这意味着你不能将一个 Mutex 锁发送到另一个线程,因为 .await 可能会让任务转移到另一个线程上执行,这个之前也介绍过。
提前释放锁
要解决这个问题,就必须重构代码,让 Mutex 锁在 .await 被调用前就被释放掉。
#![allow(unused)] fn main() { // 下面的代码可以工作! async fn increment_and_do_stuff(mutex: &Mutex<i32>) { { let mut lock: MutexGuard<i32> = mutex.lock().unwrap(); *lock += 1; } // lock在这里超出作用域 (被释放) do_something_async().await; } }
大家可能已经发现,很多错误都是因为
.await引起的,其实你只要记住,在.await执行期间,任务可能会在线程间转移,那么这些错误将变得很好理解,不必去死记硬背
但是下面的代码不工作:
#![allow(unused)] fn main() { use std::sync::{Mutex, MutexGuard}; async fn increment_and_do_stuff(mutex: &Mutex<i32>) { let mut lock: MutexGuard<i32> = mutex.lock().unwrap(); *lock += 1; drop(lock); do_something_async().await; } }
原因我们之前解释过,编译器在这里不够聪明,目前它只能根据作用域的范围来判断,drop 虽然释放了锁,但是锁的作用域依然会持续到函数的结束,未来也许编译器会改进,但是现在至少还是不行的。
聪明的读者此时的小脑袋已经飞速运转起来,既然锁没有实现 Send, 那我们主动给它实现如何?这样不就可以顺利运行了吗?答案依然是不可以,原因就是我们之前提到过的死锁,如果一个任务获取了锁,然后还没释放就在 .await 期间被挂起,接着开始执行另一个任务,这个任务又去获取锁,就会导致死锁。
再来看看其它解决方法:
重构代码:在 .await 期间不持有锁
之前的代码其实也是为了在 .await 期间不持有锁,但是我们还有更好的实现方式,例如,你可以把 Mutex 放入一个结构体中,并且只在该结构体的非异步方法中使用该锁:
#![allow(unused)] fn main() { use std::sync::Mutex; struct CanIncrement { mutex: Mutex<i32>, } impl CanIncrement { // 该方法不是 `async` fn increment(&self) { let mut lock = self.mutex.lock().unwrap(); *lock += 1; } } async fn increment_and_do_stuff(can_incr: &CanIncrement) { can_incr.increment(); do_something_async().await; } }
使用异步任务和通过消息传递来管理状态
该方法常常用于共享的资源是 I/O 类型的资源时,我们在下一章节将详细介绍。
使用 Tokio 提供的异步锁
Tokio 提供的锁最大的优点就是:它可以在 .await 执行期间被持有,而且不会有任何问题。但是代价就是,这种异步锁的性能开销会更高,因此如果可以,使用之前的两种方法来解决会更好。
#![allow(unused)] fn main() { use tokio::sync::Mutex; // 注意,这里使用的是 Tokio 提供的锁 // 下面的代码会编译 // 但是就这个例子而言,之前的方式会更好 async fn increment_and_do_stuff(mutex: &Mutex<i32>) { let mut lock = mutex.lock().await; *lock += 1; do_something_async().await; } // 锁在这里被释放 }
消息传递
迄今为止,你已经学了不少关于 Tokio 的并发编程的内容,是时候见识下真正的挑战了,接下来,我们一起来实现下客户端这块儿的功能。
首先,将之前实现的 src/main.rs 文件中的服务器端代码放入到一个 bin 文件中,等下可以直接通过该文件来运行我们的服务器:
mkdir src/bin
mv src/main.rs src/bin/server.rs
接着创建一个新的 bin 文件,用于包含我们即将实现的客户端代码:
touch src/bin/client.rs
由于不再使用 main.rs 作为程序入口,我们需要使用以下命令来运行指定的 bin 文件:
#![allow(unused)] fn main() { cargo run --bin server }
此时,服务器已经成功运行起来。 同样的,可以用 cargo run --bin client 这种方式运行即将实现的客户端。
万事俱备,只欠代码,一起来看看客户端该如何实现。
错误的实现
如果想要同时运行两个 redis 命令,我们可能会为每一个命令生成一个任务,例如:
use mini_redis::client; #[tokio::main] async fn main() { // 创建到服务器的连接 let mut client = client::connect("127.0.0.1:6379").await.unwrap(); // 生成两个任务,一个用于获取 key, 一个用于设置 key let t1 = tokio::spawn(async { let res = client.get("hello").await; }); let t2 = tokio::spawn(async { client.set("foo", "bar".into()).await; }); t1.await.unwrap(); t2.await.unwrap(); }
这段代码不会编译,因为两个任务都需要去访问 client,但是 client 并没有实现 Copy 特征,再加上我们并没有实现相应的共享代码,因此自然会报错。还有一个问题,方法 set 和 get 都使用了 client 的可变引用 &mut self,由此还会造成同时借用两个可变引用的错误。
在上一节中,我们介绍了几个解决方法,但是它们大部分都不太适用于此时的情况,例如:
std::sync::Mutex无法被使用,这个问题在之前章节有详解介绍过,同步锁无法跨越.await调用时使用- 那么你可能会想,是不是可以使用
tokio::sync:Mutex,答案是可以用,但是同时就只能运行一个请求。若客户端实现了 redis 的 pipelining, 那这个异步锁就会导致连接利用率不足
这个不行,那个也不行,是不是没有办法解决了?还记得我们上一章节提到过几次的消息传递,但是一直没有看到它的庐山真面目吗?现在可以来看看了。
消息传递
之前章节我们提到可以创建一个专门的任务 C1 (消费者 Consumer) 和通过消息传递来管理共享的资源,这里的共享资源就是 client 。若任务 P1 (生产者 Producer) 想要发出 Redis 请求,首先需要发送信息给 C1,然后 C1 会发出请求给服务器,在获取到结果后,再将结果返回给 P1。
在这种模式下,只需要建立一条连接,然后由一个统一的任务来管理 client 和该连接,这样之前的 get 和 set 请求也将不存在资源共享的问题。
同时,P1 和 C1 进行通信的消息通道是有缓冲的,当大量的消息发送给 C1 时,首先会放入消息通道的缓冲区中,当 C1 处理完一条消息后,再从该缓冲区中取出下一条消息进行处理,这种方式跟消息队列( Message queue ) 非常类似,可以实现更高的吞吐。而且这种方式还有利于实现连接池,例如不止一个 P 和 C 时,多个 P 可以往消息通道中发送消息,同时多个 C,其中每个 C 都维护一条连接,并从消息通道获取消息。
Tokio 的消息通道( channel )
Tokio 提供了多种消息通道,可以满足不同场景的需求:
mpsc, 多生产者,单消费者模式oneshot, 单生产者单消费,一次只能发送一条消息broadcast,多生产者,多消费者,其中每一条发送的消息都可以被所有接收者收到,因此是广播watch,单生产者,多消费者,只保存一条最新的消息,因此接收者只能看到最近的一条消息,例如,这种模式适用于配置文件变化的监听
细心的同学可能会发现,这里还少了一种类型:多生产者、多消费者,且每一条消息只能被其中一个消费者接收,如果有这种需求,可以使用 async-channel 包。
以上这些消息通道都有一个共同点:适用于 async 编程,对于其它场景,你可以使用在多线程章节中提到过的 std::sync::mpsc 和 crossbeam::channel, 这些通道在等待消息时会阻塞当前的线程,因此不适用于 async 编程。
在下面的代码中,我们将使用 mpsc 和 oneshot, 本章节完整的代码见这里。
定义消息类型
在大多数场景中使用消息传递时,都是多个发送者向一个任务发送消息,该任务在处理完后,需要将响应内容返回给相应的发送者。例如我们的例子中,任务需要将 GET 和 SET 命令处理的结果返回。首先,我们需要定一个 Command 枚举用于代表命令:
#![allow(unused)] fn main() { use bytes::Bytes; #[derive(Debug)] enum Command { Get { key: String, }, Set { key: String, val: Bytes, } } }
创建消息通道
在 src/bin/client.rs 的 main 函数中,创建一个 mpsc 消息通道:
use tokio::sync::mpsc; #[tokio::main] async fn main() { // 创建一个新通道,缓冲队列长度是 32 let (tx, mut rx) = mpsc::channel(32); // ... 其它代码 }
一个任务可以通过此通道将命令发送给管理 redis 连接的任务,同时由于通道支持多个生产者,因此多个任务可以同时发送命令。创建该通道会返回一个发送和接收句柄,这两个句柄可以分别被使用,例如它们可以被移动到不同的任务中。
通道的缓冲队列长度是 32,意味着如果消息发送的比接收的快,这些消息将被存储在缓冲队列中,一旦存满了 32 条消息,使用send(...).await的发送者会进入睡眠,直到缓冲队列可以放入新的消息(被接收者消费了)。
use tokio::sync::mpsc; #[tokio::main] async fn main() { let (tx, mut rx) = mpsc::channel(32); let tx2 = tx.clone(); tokio::spawn(async move { tx.send("sending from first handle").await; }); tokio::spawn(async move { tx2.send("sending from second handle").await; }); while let Some(message) = rx.recv().await { println!("GOT = {}", message); } }
你可以使用 clone 方法克隆多个发送者,但是接收者无法被克隆,因为我们的通道是 mpsc 类型。
当所有的发送者都被 Drop 掉后(超出作用域或被 drop(...) 函数主动释放),就不再会有任何消息发送给该通道,此时 recv 方法将返回 None,也意味着该通道已经被关闭。
在我们的例子中,接收者是在管理 redis 连接的任务中,当该任务发现所有发送者都关闭时,它知道它的使命可以完成了,因此它会关闭 redis 连接。
生成管理任务
下面,我们来一起创建一个管理任务,它会管理 redis 的连接,当然,首先需要创建一条到 redis 的连接:
#![allow(unused)] fn main() { use mini_redis::client; // 将消息通道接收者 rx 的所有权转移到管理任务中 let manager = tokio::spawn(async move { // Establish a connection to the server // 建立到 redis 服务器的连接 let mut client = client::connect("127.0.0.1:6379").await.unwrap(); // 开始接收消息 while let Some(cmd) = rx.recv().await { use Command::*; match cmd { Get { key } => { client.get(&key).await; } Set { key, val } => { client.set(&key, val).await; } } } }); }
如上所示,当从消息通道接收到一个命令时,该管理任务会将此命令通过 redis 连接发送到服务器。
现在,让两个任务发送命令到消息通道,而不是像最开始报错的那样,直接发送命令到各自的 redis 连接:
#![allow(unused)] fn main() { // 由于有两个任务,因此我们需要两个发送者 let tx2 = tx.clone(); // 生成两个任务,一个用于获取 key,一个用于设置 key let t1 = tokio::spawn(async move { let cmd = Command::Get { key: "hello".to_string(), }; tx.send(cmd).await.unwrap(); }); let t2 = tokio::spawn(async move { let cmd = Command::Set { key: "foo".to_string(), val: "bar".into(), }; tx2.send(cmd).await.unwrap(); }); }
在 main 函数的末尾,我们让 3 个任务,按照需要的顺序开始运行:
#![allow(unused)] fn main() { t1.await.unwrap(); t2.await.unwrap(); manager.await.unwrap(); }
接收响应消息
最后一步,就是让发出命令的任务从管理任务那里获取命令执行的结果。为了完成这个目标,我们将使用 oneshot 消息通道,因为它针对一发一收的使用类型做过特别优化,且特别适用于此时的场景:接收一条从管理任务发送的结果消息。
#![allow(unused)] fn main() { use tokio::sync::oneshot; let (tx, rx) = oneshot::channel(); }
使用方式跟 mpsc 很像,但是它并没有缓存长度,因为只能发送一条,接收一条,还有一点不同:你无法对返回的两个句柄进行 clone。
为了让管理任务将结果准确的返回到发送者手中,这个管道的发送端必须要随着命令一起发送, 然后发出命令的任务保留管道的发送端。一个比较好的实现就是将管道的发送端放入 Command 的数据结构中,同时使用一个别名来代表该发送端:
#![allow(unused)] fn main() { use tokio::sync::oneshot; use bytes::Bytes; #[derive(Debug)] enum Command { Get { key: String, resp: Responder<Option<Bytes>>, }, Set { key: String, val: Bytes, resp: Responder<()>, }, } /// 管理任务可以使用该发送端将命令执行的结果传回给发出命令的任务 type Responder<T> = oneshot::Sender<mini_redis::Result<T>>; }
下面,更新发送命令的代码:
#![allow(unused)] fn main() { let t1 = tokio::spawn(async move { let (resp_tx, resp_rx) = oneshot::channel(); let cmd = Command::Get { key: "hello".to_string(), resp: resp_tx, }; // 发送 GET 请求 tx.send(cmd).await.unwrap(); // 等待回复 let res = resp_rx.await; println!("GOT = {:?}", res); }); let t2 = tokio::spawn(async move { let (resp_tx, resp_rx) = oneshot::channel(); let cmd = Command::Set { key: "foo".to_string(), val: "bar".into(), resp: resp_tx, }; // 发送 SET 请求 tx2.send(cmd).await.unwrap(); // 等待回复 let res = resp_rx.await; println!("GOT = {:?}", res); }); }
最后,更新管理任务:
#![allow(unused)] fn main() { while let Some(cmd) = rx.recv().await { match cmd { Command::Get { key, resp } => { let res = client.get(&key).await; // 忽略错误 let _ = resp.send(res); } Command::Set { key, val, resp } => { let res = client.set(&key, val).await; // 忽略错误 let _ = resp.send(res); } } } }
有一点值得注意,往 oneshot 中发送消息时,并没有使用 .await,原因是该发送操作要么直接成功、要么失败,并不需要等待。
当 oneshot 的接受端被 drop 后,继续发送消息会直接返回 Err 错误,它表示接收者已经不感兴趣了。对于我们的场景,接收者不感兴趣是非常合理的操作,并不是一种错误,因此可以直接忽略。
本章的完整代码见这里。
对消息通道进行限制
无论何时使用消息通道,我们都需要对缓存队列的长度进行限制,这样系统才能优雅的处理各种负载状况。如果不限制,假设接收端无法及时处理消息,那消息就会迅速堆积,最终可能会导致内存消耗殆尽,就算内存没有消耗完,也可能会导致整体性能的大幅下降。
Tokio 在设计时就考虑了这种状况,例如 async 操作在 Tokio 中是惰性的:
#![allow(unused)] fn main() { loop { async_op(); } }
如果上面代码中,async_op 不是惰性的,而是在每次循环时立即执行,那该循环会立即将一个 async_op 发送到缓冲队列中,然后开始执行下一个循环,因为无需等待任务执行完成,这种发送速度是非常恐怖的,一秒钟可能会有几十万、上百万的消息发送到消息队列中。在其它语言编程中,相信大家也或多或少遇到过这种情况。
然后在 Async Rust 和 Tokio 中,上面的代码 async_op 根本就不会运行,也就不会往消息队列中写入消息。原因是我们没有调用 .await,就算使用了 .await 上面的代码也不会有问题,因为只有等当前循环的任务结束后,才会开始下一次循环。
#![allow(unused)] fn main() { loop { // 当前 `async_op` 完成后,才会开始下一次循环 async_op().await; } }
总之,在 Tokio 中我们必须要显式地引入并发和队列:
tokio::spawnselect!join!mpsc::channel
当这么做时,我们需要小心的控制并发度来确保系统的安全。例如,当使用一个循环去接收 TCP 连接时,你要确保当前打开的 socket 数量在可控范围内,而不是毫无原则的接收连接。 再比如,当使用 mpsc::channel 时,要设置一个缓冲值。
挑选一个合适的限制值是 Tokio 编程中很重要的一部分,可以帮助我们的系统更加安全、可靠的运行。
I/O
本章节中我们将深入学习 Tokio 中的 I/O 操作,了解它的原理以及该如何使用。
Tokio 中的 I/O 操作和 std 在使用方式上几无区别,最大的区别就是前者是异步的,例如 Tokio 的读写特征分别是 AsyncRead 和 AsyncWrite:
- 有部分类型按照自己的所需实现了它们:
TcpStream,File,Stdout - 还有数据结构也实现了它们:
Vec<u8>、&[u8],这样就可以直接使用这些数据结构作为读写器( reader / writer)
AsyncRead 和 AsyncWrite
这两个特征为字节流的异步读写提供了便利,通常我们会使用 AsyncReadExt 和 AsyncWriteExt 提供的工具方法,这些方法都使用 async 声明,且需要通过 .await 进行调用,
async fn read
AsyncReadExt::read 是一个异步方法可以将数据读入缓冲区( buffer )中,然后返回读取的字节数。
use tokio::fs::File; use tokio::io::{self, AsyncReadExt}; #[tokio::main] async fn main() -> io::Result<()> { let mut f = File::open("foo.txt").await?; let mut buffer = [0; 10]; // 由于 buffer 的长度限制,当次的 `read` 调用最多可以从文件中读取 10 个字节的数据 let n = f.read(&mut buffer[..]).await?; println!("The bytes: {:?}", &buffer[..n]); Ok(()) }
需要注意的是:当 read 返回 Ok(0) 时,意味着字节流( stream )已经关闭,在这之后继续调用 read 会立刻完成,依然获取到返回值 Ok(0)。 例如,字节流如果是 TcpStream 类型,那 Ok(0) 说明该连接的读取端已经被关闭(写入端关闭,会报其它的错误)。
async fn read_to_end
AsyncReadExt::read_to_end 方法会从字节流中读取所有的字节,直到遇到 EOF :
use tokio::io::{self, AsyncReadExt}; use tokio::fs::File; #[tokio::main] async fn main() -> io::Result<()> { let mut f = File::open("foo.txt").await?; let mut buffer = Vec::new(); // 读取整个文件的内容 f.read_to_end(&mut buffer).await?; Ok(()) }
async fn write
AsyncWriteExt::write 异步方法会尝试将缓冲区的内容写入到写入器( writer )中,同时返回写入的字节数:
use tokio::io::{self, AsyncWriteExt}; use tokio::fs::File; #[tokio::main] async fn main() -> io::Result<()> { let mut file = File::create("foo.txt").await?; let n = file.write(b"some bytes").await?; println!("Wrote the first {} bytes of 'some bytes'.", n); Ok(()) }
上面代码很清晰,但是大家可能会疑惑 b"some bytes" 是什么意思。这种写法可以将一个 &str 字符串转变成一个字节数组:&[u8;10],然后 write 方法又会将这个 &[u8;10] 的数组类型隐式强转为数组切片: &[u8]。
async fn write_all
AsyncWriteExt::write_all 将缓冲区的内容全部写入到写入器中:
use tokio::io::{self, AsyncWriteExt}; use tokio::fs::File; #[tokio::main] async fn main() -> io::Result<()> { let mut file = File::create("foo.txt").await?; file.write_all(b"some bytes").await?; Ok(()) }
以上只是部分方法,实际上还有一些实用的方法由于篇幅有限无法列出,大家可以通过 API 文档 查看完整的列表。
实用函数
另外,和标准库一样, tokio::io 模块包含了多个实用的函数或 API,可以用于处理标准输入/输出/错误等。
例如,tokio::io::copy 异步的将读取器( reader )中的内容拷贝到写入器( writer )中。
use tokio::fs::File; use tokio::io; #[tokio::main] async fn main() -> io::Result<()> { let mut reader: &[u8] = b"hello"; let mut file = File::create("foo.txt").await?; io::copy(&mut reader, &mut file).await?; Ok(()) }
还记得我们之前提到的字节数组 &[u8] 实现了 AsyncRead 吗?正因为这个原因,所以这里可以直接将 &u8 用作读取器。
回声服务( Echo )
就如同写代码必写 hello, world,实现 web 服务器,往往会选择实现一个回声服务。该服务会将用户的输入内容直接返回给用户,就像回声壁一样。
具体来说,就是从用户建立的 TCP 连接的 socket 中读取到数据,然后立刻将同样的数据写回到该 socket 中。因此客户端会收到和自己发送的数据一模一样的回复。
下面我们将使用两种稍有不同的方法实现该回声服务。
使用 io::copy()
先来创建一个新的 bin 文件,用于运行我们的回声服务:
touch src/bin/echo-server-copy.rs
然后可以通过以下命令运行它(跟上一章节的方式相同):
cargo run --bin echo-server-copy
至于客户端,可以简单的使用 telnet 的方式来连接,或者也可以使用 tokio::net::TcpStream,它的文档示例非常适合大家进行参考。
先来实现一下基本的服务器框架:通过 loop 循环接收 TCP 连接,然后为每一条连接创建一个单独的任务去处理。
use tokio::io; use tokio::net::TcpListener; #[tokio::main] async fn main() -> io::Result<()> { let listener = TcpListener::bind("127.0.0.1:6142").await?; loop { let (mut socket, _) = listener.accept().await?; tokio::spawn(async move { // 在这里拷贝数据 }); } }
下面,来看看重头戏 io::copy ,它有两个参数:一个读取器,一个写入器,然后将读取器中的数据直接拷贝到写入器中,类似的实现代码如下:
#![allow(unused)] fn main() { io::copy(&mut socket, &mut socket).await }
这段代码相信大家一眼就能看出问题,由于我们的读取器和写入器都是同一个 socket,因此需要对其进行两次可变借用,这明显违背了 Rust 的借用规则。
分离读写器
显然,使用同一个 socket 是不行的,为了实现目标功能,必须将 socket 分离成一个读取器和写入器。
任何一个读写器( reader + writer )都可以使用 io::split 方法进行分离,最终返回一个读取器和写入器,这两者可以独自的使用,例如可以放入不同的任务中。
例如,我们的回声客户端可以这样实现,以实现同时并发读写:
use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; #[tokio::main] async fn main() -> io::Result<()> { let socket = TcpStream::connect("127.0.0.1:6142").await?; let (mut rd, mut wr) = io::split(socket); // 创建异步任务,在后台写入数据 tokio::spawn(async move { wr.write_all(b"hello\r\n").await?; wr.write_all(b"world\r\n").await?; // 有时,我们需要给予 Rust 一些类型暗示,它才能正确的推导出类型 Ok::<_, io::Error>(()) }); let mut buf = vec![0; 128]; loop { let n = rd.read(&mut buf).await?; if n == 0 { break; } println!("GOT {:?}", &buf[..n]); } Ok(()) }
实际上,io::split 可以用于任何同时实现了 AsyncRead 和 AsyncWrite 的值,它的内部使用了 Arc 和 Mutex 来实现相应的功能。如果大家觉得这种实现有些重,可以使用 Tokio 提供的 TcpStream,它提供了两种方式进行分离:
TcpStream::split会获取字节流的引用,然后将其分离成一个读取器和写入器。但由于使用了引用的方式,它们俩必须和split在同一个任务中。 优点就是,这种实现没有性能开销,因为无需Arc和Mutex。TcpStream::into_split还提供了一种分离实现,分离出来的结果可以在任务间移动,内部是通过Arc实现
再来分析下我们的使用场景,由于 io::copy() 调用时所在的任务和 split 所在的任务是同一个,因此可以使用性能最高的 TcpStream::split:
#![allow(unused)] fn main() { tokio::spawn(async move { let (mut rd, mut wr) = socket.split(); if io::copy(&mut rd, &mut wr).await.is_err() { eprintln!("failed to copy"); } }); }
使用 io::copy 实现的完整代码见此处。
手动拷贝
程序员往往拥有一颗手动干翻一切的心,因此如果你不想用 io::copy 来简单实现,还可以自己手动去拷贝数据:
use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; #[tokio::main] async fn main() -> io::Result<()> { let listener = TcpListener::bind("127.0.0.1:6142").await?; loop { let (mut socket, _) = listener.accept().await?; tokio::spawn(async move { let mut buf = vec![0; 1024]; loop { match socket.read(&mut buf).await { // 返回值 `Ok(0)` 说明对端已经关闭 Ok(0) => return, Ok(n) => { // Copy the data back to socket // 将数据拷贝回 socket 中 if socket.write_all(&buf[..n]).await.is_err() { // 非预期错误,由于我们这里无需再做什么,因此直接停止处理 return; } } Err(_) => { // 非预期错误,由于我们无需再做什么,因此直接停止处理 return; } } } }); } }
建议这段代码放入一个和之前 io::copy 不同的文件中 src/bin/echo-server.rs , 然后使用 cargo run --bin echo-server 运行。
下面一起来看看这段代码有哪些值得注意的地方。首先,由于使用了 write_all 和 read 方法,需要先将对应的特征引入到当前作用域内:
#![allow(unused)] fn main() { use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; }
在堆上分配缓冲区
在上面代码中,我们需要将数据从 socket 中读取到一个缓冲区 buffer 中:
#![allow(unused)] fn main() { let mut buf = vec![0; 1024]; }
可以看到,此处的缓冲区是一个 Vec 动态数组,它的数据是存储在堆上,而不是栈上(若改成 let mut buf = [0; 1024];,则存储在栈上)。
在之前,我们提到过一个数据如果想在 .await 调用过程中存在,那它必须存储在当前任务内。在我们的代码中,buf 会在 .await 调用过程中被使用,因此它必须要存储在任务内。
若该缓冲区数组创建在栈上,那每条连接所对应的任务的内部数据结构看上去可能如下所示:
#![allow(unused)] fn main() { struct Task { task: enum { AwaitingRead { socket: TcpStream, buf: [BufferType], }, AwaitingWriteAll { socket: TcpStream, buf: [BufferType], } } } }
可以看到,栈数组要被使用,就必须存储在相应的结构体内,其中两个结构体分别持有了不同的栈数组 [BufferType],这种方式会导致任务结构变得很大。特别地,我们选择缓冲区长度往往会使用分页长度(page size),因此使用栈数组会导致任务的内存大小变得很奇怪甚至糟糕:$page-size + 一些额外的字节。
当然,编译器会帮助我们做一些优化。例如,会进一步优化 async 语句块的布局,而不是像上面一样简单的使用 enum。在实践中,变量也不会在枚举成员间移动。
但是再怎么优化,任务的结构体至少也会跟其中的栈数组一样大,因此通常情况下,使用堆上的缓冲区会高效实用的多。
当任务因为调度在线程间移动时,存储在栈上的数据需要进行保存和恢复,过大的栈上变量会带来不小的数据拷贝开销
因此,存储大量数据的变量最好放到堆上
处理 EOF
当 TCP 连接的读取端关闭后,再调用 read 方法会返回 Ok(0)。此时,再继续下去已经没有意义,因此我们需要退出循环。忘记在 EOF 时退出读取循环,是网络编程中一个常见的 bug :
#![allow(unused)] fn main() { loop { match socket.read(&mut buf).await { Ok(0) => return, // ... 其余错误处理 } } }
大家不妨深入思考下,如果没有退出循环会怎么样?之前我们提到过,一旦读取端关闭后,那后面的 read 调用就会立即返回 Ok(0),而不会阻塞等待,因此这种无阻塞循环会最终导致 CPU 立刻跑到 100% ,并将一直持续下去,直到程序关闭。
解析数据帧
现在,鉴于大家已经掌握了 Tokio 的基本 I/O 用法,我们可以开始实现 mini-redis 的帧 frame。通过帧可以将字节流转换成帧组成的流。每个帧就是一个数据单元,例如客户端发送的一次请求就是一个帧。
#![allow(unused)] fn main() { use bytes::Bytes; enum Frame { Simple(String), Error(String), Integer(u64), Bulk(Bytes), Null, Array(Vec<Frame>), } }
可以看到帧除了数据之外,并不具备任何语义。命令解析和实现会在更高的层次进行(相比帧解析层)。我们再来通过 HTTP 的帧来帮大家加深下相关的理解:
#![allow(unused)] fn main() { enum HttpFrame { RequestHead { method: Method, uri: Uri, version: Version, headers: HeaderMap, }, ResponseHead { status: StatusCode, version: Version, headers: HeaderMap, }, BodyChunk { chunk: Bytes, }, } }
为了实现 mini-redis 的帧,我们需要一个 Connection 结构体,里面包含了一个 TcpStream 以及对帧进行读写的方法:
#![allow(unused)] fn main() { use tokio::net::TcpStream; use mini_redis::{Frame, Result}; struct Connection { stream: TcpStream, // ... 这里定义其它字段 } impl Connection { /// 从连接读取一个帧 /// /// 如果遇到EOF,则返回 None pub async fn read_frame(&mut self) -> Result<Option<Frame>> { // 具体实现 } /// 将帧写入到连接中 pub async fn write_frame(&mut self, frame: &Frame) -> Result<()> { // 具体实现 } } }
关于 Redis 协议的说明,可以看看官方文档,Connection 代码的完整实现见这里.
缓冲读取(Buffered Reads)
read_frame 方法会等到一个完整的帧都读取完毕后才返回,与之相比,它底层调用的TcpStream::read 只会返回任意多的数据(填满传入的缓冲区 buffer ),它可能返回帧的一部分、一个帧、多个帧,总之这种读取行为是不确定的。
当 read_frame 的底层调用 TcpStream::read 读取到部分帧时,会将数据先缓冲起来,接着继续等待并读取数据。如果读到多个帧,那第一个帧会被返回,然后剩下的数据依然被缓冲起来,等待下一次 read_frame 被调用。
为了实现这种功能,我们需要为 Connection 增加一个读取缓冲区。数据首先从 socket 中读取到缓冲区中,接着这些数据会被解析为帧,当一个帧被解析后,该帧对应的数据会从缓冲区被移除。
这里使用 BytesMut 作为缓冲区类型,它是 Bytes 的可变版本。
#![allow(unused)] fn main() { use bytes::BytesMut; use tokio::net::TcpStream; pub struct Connection { stream: TcpStream, buffer: BytesMut, } impl Connection { pub fn new(stream: TcpStream) -> Connection { Connection { stream, // 分配一个缓冲区,具有4kb的缓冲长度 buffer: BytesMut::with_capacity(4096), } } } }
接下来,实现 read_frame 方法:
#![allow(unused)] fn main() { use tokio::io::AsyncReadExt; use bytes::Buf; use mini_redis::Result; pub async fn read_frame(&mut self) -> Result<Option<Frame>> { loop { // 尝试从缓冲区的数据中解析出一个数据帧, // 只有当数据足够被解析时,才返回对应的帧 if let Some(frame) = self.parse_frame()? { return Ok(Some(frame)); } // 如果缓冲区中的数据还不足以被解析为一个数据帧, // 那么我们需要从 socket 中读取更多的数据 // // 读取成功时,会返回读取到的字节数,0 代表着读到了数据流的末尾 if 0 == self.stream.read_buf(&mut self.buffer).await? { // 代码能执行到这里,说明了对端关闭了连接, // 需要看看缓冲区是否还有数据,若没有数据,说明所有数据成功被处理, // 若还有数据,说明对端在发送帧的过程中断开了连接,导致只发送了部分数据 if self.buffer.is_empty() { return Ok(None); } else { return Err("connection reset by peer".into()); } } } } }
read_frame 内部使用循环的方式读取数据,直到一个完整的帧被读取到时,才会返回。当然,当远程的对端关闭了连接后,也会返回。
Buf 特征
在上面的 read_frame 方法中,我们使用了 read_buf 来读取 socket 中的数据,该方法的参数是来自 bytes 包的 BufMut。
可以先来考虑下该如何使用 read() 和 Vec<u8> 来实现同样的功能 :
#![allow(unused)] fn main() { use tokio::net::TcpStream; pub struct Connection { stream: TcpStream, buffer: Vec<u8>, cursor: usize, } impl Connection { pub fn new(stream: TcpStream) -> Connection { Connection { stream, // 4kb 大小的缓冲区 buffer: vec![0; 4096], cursor: 0, } } } }
下面是相应的 read_frame 方法:
#![allow(unused)] fn main() { use mini_redis::{Frame, Result}; pub async fn read_frame(&mut self) -> Result<Option<Frame>> { loop { if let Some(frame) = self.parse_frame()? { return Ok(Some(frame)); } // 确保缓冲区长度足够 if self.buffer.len() == self.cursor { // 若不够,需要增加缓冲区长度 self.buffer.resize(self.cursor * 2, 0); } // 从游标位置开始将数据读入缓冲区 let n = self.stream.read( &mut self.buffer[self.cursor..]).await?; if 0 == n { if self.cursor == 0 { return Ok(None); } else { return Err("connection reset by peer".into()); } } else { // 更新游标位置 self.cursor += n; } } } }
在这段代码中,我们使用了非常重要的技术:通过游标( cursor )跟踪已经读取的数据,并将下次读取的数据写入到游标之后的缓冲区中,只有这样才不会让新读取的数据将之前读取的数据覆盖掉。
一旦缓冲区满了,还需要增加缓冲区的长度,这样才能继续写入数据。还有一点值得注意,在 parse_frame 方法的内部实现中,也需要通过游标来解析数据: self.buffer[..self.cursor],通过这种方式,我们可以准确获取到目前已经读取的全部数据。
在网络编程中,通过字节数组和游标的方式读取数据是非常普遍的,因此 bytes 包提供了一个 Buf 特征,如果一个类型可以被读取数据,那么该类型需要实现 Buf 特征。与之对应,当一个类型可以被写入数据时,它需要实现 BufMut 。
当 T: BufMut ( 特征约束,说明类型 T 实现了 BufMut 特征 ) 被传给 read_buf() 方法时,缓冲区 T 的内部游标会自动进行更新。正因为如此,在使用了 BufMut 版本的 read_frame 中,我们并不需要管理自己的游标。
除了游标之外,Vec<u8> 的使用也值得关注,该缓冲区在使用时必须要被初始化: vec![0; 4096],该初始化会创建一个 4096 字节长度的数组,然后将数组的每个元素都填充上 0 。当缓冲区长度不足时,新创建的缓冲区数组依然会使用 0 被重新填充一遍。 事实上,这种初始化过程会存在一定的性能开销。
与 Vec<u8> 相反, BytesMut 和 BufMut 就没有这个问题,它们无需被初始化,而且 BytesMut 还会阻止我们读取未初始化的内存。
帧解析
在理解了该如何读取数据后, 再来看看该如何通过两个部分解析出一个帧:
- 确保有一个完整的帧已经被写入了缓冲区,找到该帧的最后一个字节所在的位置
- 解析帧
#![allow(unused)] fn main() { use mini_redis::{Frame, Result}; use mini_redis::frame::Error::Incomplete; use bytes::Buf; use std::io::Cursor; fn parse_frame(&mut self) -> Result<Option<Frame>> { // 创建 `T: Buf` 类型 let mut buf = Cursor::new(&self.buffer[..]); // 检查是否读取了足够解析出一个帧的数据 match Frame::check(&mut buf) { Ok(_) => { // 获取组成该帧的字节数 let len = buf.position() as usize; // 在解析开始之前,重置内部的游标位置 buf.set_position(0); // 解析帧 let frame = Frame::parse(&mut buf)?; // 解析完成,将缓冲区该帧的数据移除 self.buffer.advance(len); // 返回解析出的帧 Ok(Some(frame)) } // 缓冲区的数据不足以解析出一个完整的帧 Err(Incomplete) => Ok(None), // 遇到一个错误 Err(e) => Err(e.into()), } } }
完整的 Frame::check 函数实现在这里,感兴趣的同学可以看看,在这里我们不会对它进行完整的介绍。
值得一提的是, Frame::check 使用了 Buf 的字节迭代风格的 API。例如,为了解析一个帧,首先需要检查它的第一个字节,该字节用于说明帧的类型。这种首字节检查是通过 Buf::get_u8 函数完成的,该函数会获取游标所在位置的字节,然后将游标位置向右移动一个字节。
缓冲写入(Buffered writes)
关于帧操作的另一个 API 是 write_frame(frame) 函数,它会将一个完整的帧写入到 socket 中。 每一次写入,都会触发一次或数次系统调用,当程序中有大量的连接和写入时,系统调用的开销将变得非常高昂,具体可以看看 SyllaDB 团队写过的一篇性能调优文章。
为了降低系统调用的次数,我们需要使用一个写入缓冲区,当写入一个帧时,首先会写入该缓冲区,然后等缓冲区数据足够多时,再集中将其中的数据写入到 socket 中,这样就将多次系统调用优化减少到一次。
还有,缓冲区也不总是能提升性能。 例如,考虑一个 bulk 帧(多个帧放在一起组成一个 bulk,通过批量发送提升效率),该帧的特点就是:由于由多个帧组合而成,因此帧体数据可能会很大。所以我们不能将其帧体数据写入到缓冲区中,因为数据较大时,先写入缓冲区再写入 socket 会有较大的性能开销(实际上缓冲区就是为了批量写入,既然 bulk 已经是批量了,因此不使用缓冲区也很正常)。
为了实现缓冲写,我们将使用 BufWriter 结构体。该结构体实现了 AsyncWrite 特征,当 write 方法被调用时,不会直接写入到 socket 中,而是先写入到缓冲区中。当缓冲区被填满时,其中的内容会自动刷到(写入到)内部的 socket 中,然后再将缓冲区清空。当然,其中还存在某些优化,通过这些优化可以绕过缓冲区直接访问 socket。
由于篇幅有限,我们不会实现完整的 write_frame 函数,想要看完整代码可以访问这里。
首先,更新下 Connection 的结构体:
#![allow(unused)] fn main() { use tokio::io::BufWriter; use tokio::net::TcpStream; use bytes::BytesMut; pub struct Connection { stream: BufWriter<TcpStream>, buffer: BytesMut, } impl Connection { pub fn new(stream: TcpStream) -> Connection { Connection { stream: BufWriter::new(stream), buffer: BytesMut::with_capacity(4096), } } } }
接着来实现 write_frame 函数:
#![allow(unused)] fn main() { use tokio::io::{self, AsyncWriteExt}; use mini_redis::Frame; async fn write_frame(&mut self, frame: &Frame) -> io::Result<()> { match frame { Frame::Simple(val) => { self.stream.write_u8(b'+').await?; self.stream.write_all(val.as_bytes()).await?; self.stream.write_all(b"\r\n").await?; } Frame::Error(val) => { self.stream.write_u8(b'-').await?; self.stream.write_all(val.as_bytes()).await?; self.stream.write_all(b"\r\n").await?; } Frame::Integer(val) => { self.stream.write_u8(b':').await?; self.write_decimal(*val).await?; } Frame::Null => { self.stream.write_all(b"$-1\r\n").await?; } Frame::Bulk(val) => { let len = val.len(); self.stream.write_u8(b'$').await?; self.write_decimal(len as u64).await?; self.stream.write_all(val).await?; self.stream.write_all(b"\r\n").await?; } Frame::Array(_val) => unimplemented!(), } self.stream.flush().await; Ok(()) } }
这里使用的方法由 AsyncWriteExt 提供,它们在 TcpStream 中也有对应的函数。但是在没有缓冲区的情况下最好避免使用这种逐字节的写入方式!不然,每写入几个字节就会触发一次系统调用,写完整个数据帧可能需要几十次系统调用,可以说是丧心病狂!
write_u8写入一个字节write_all写入所有数据write_decimal由 mini-redis 提供
在函数结束前,我们还额外的调用了一次 self.stream.flush().await,原因是缓冲区可能还存在数据,因此需要手动刷一次数据:flush 的调用会将缓冲区中剩余的数据立刻写入到 socket 中。
当然,当帧比较小的时候,每写一次帧就 flush 一次的模式性能开销会比较大,此时我们可以选择在 Connection 中实现 flush 函数,然后将等帧积累多个后,再一次性在 Connection 中进行 flush。当然,对于我们的例子来说,简洁性是非常重要的,因此选了将 flush 放入到 write_frame 中。
深入 Tokio 背后的异步原理
在经过多个章节的深入学习后,Tokio 对我们来说不再是一座隐于云雾中的高山,它其实蛮简单好用的,甚至还有一丝丝的可爱!?
但从现在开始,如果想要进一步的深入 Tokio ,首先需要深入理解 async 的原理,其实我们在之前的章节已经深入学习过,这里结合 Tokio 再来回顾下。
Future
先来回顾一下 async fn 异步函数 :
#![allow(unused)] fn main() { use tokio::net::TcpStream; async fn my_async_fn() { println!("hello from async"); // 通过 .await 创建 socket 连接 let _socket = TcpStream::connect("127.0.0.1:3000").await.unwrap(); println!("async TCP operation complete"); // 关闭socket } }
接着对它进行调用获取一个返回值,再在返回值上调用 .await:
#[tokio::main] async fn main() { let what_is_this = my_async_fn(); // 上面的调用不会产生任何效果 // ... 执行一些其它代码 what_is_this.await; // 直到 .await 后,文本才被打印,socket 连接也被创建和关闭 }
在上面代码中 my_async_fn 函数为何可以惰性执行( 直到 .await 调用时才执行)?秘密就在于 async fn 声明的函数返回一个 Future。
Future 是一个实现了 std::future::Future 特征的值,该值包含了一系列异步计算过程,而这个过程直到 .await 调用时才会被执行。
std::future::Future 的定义如下所示:
#![allow(unused)] fn main() { use std::pin::Pin; use std::task::{Context, Poll}; pub trait Future { type Output; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output>; } }
代码中有几个关键点:
和其它语言不同,Rust 中的 Future 不代表一个发生在后台的计算,而是 Future 就代表了计算本身,因此
Future 的所有者有责任去推进该计算过程的执行,例如通过 Future::poll 函数。听上去好像还挺复杂?但是大家不必担心,因为这些都在 Tokio 中帮你自动完成了 :)
实现 Future
下面来一起实现个五脏俱全的 Future,它将:1. 等待某个特定时间点的到来 2. 在标准输出打印文本 3. 生成一个字符串
use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::{Duration, Instant}; struct Delay { when: Instant, } // 为我们的 Delay 类型实现 Future 特征 impl Future for Delay { type Output = &'static str; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<&'static str> { if Instant::now() >= self.when { // 时间到了,Future 可以结束 println!("Hello world"); // Future 执行结束并返回 "done" 字符串 Poll::Ready("done") } else { // 目前先忽略下面这行代码 cx.waker().wake_by_ref(); Poll::Pending } } } #[tokio::main] async fn main() { let when = Instant::now() + Duration::from_millis(10); let future = Delay { when }; // 运行并等待 Future 的完成 let out = future.await; // 判断 Future 返回的字符串是否是 "done" assert_eq!(out, "done"); }
以上代码很清晰的解释了如何自定义一个 Future,并指定它如何通过 poll 一步一步执行,直到最终完成返回 "done" 字符串。
async fn 作为 Future
大家有没有注意到,上面代码我们在 main 函数中初始化一个 Future 并使用 .await 对其进行调用执行,如果你是在 fn main 中这么做,是会报错的。
原因是 .await 只能用于 async fn 函数中,因此我们将 main 函数声明成 async fn main 同时使用 #[tokio::main] 进行了标注,此时 async fn main 生成的代码类似下面:
#![allow(unused)] fn main() { use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::{Duration, Instant}; enum MainFuture { // 初始化,但永远不会被 poll State0, // 等待 `Delay` 运行,例如 `future.await` 代码行 State1(Delay), // Future 执行完成 Terminated, } impl Future for MainFuture { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { use MainFuture::*; loop { match *self { State0 => { let when = Instant::now() + Duration::from_millis(10); let future = Delay { when }; *self = State1(future); } State1(ref mut my_future) => { match Pin::new(my_future).poll(cx) { Poll::Ready(out) => { assert_eq!(out, "done"); *self = Terminated; return Poll::Ready(()); } Poll::Pending => { return Poll::Pending; } } } Terminated => { panic!("future polled after completion") } } } } } }
可以看出,编译器会将 Future 变成状态机, 其中 MainFuture 包含了 Future 可能处于的状态:从 State0 状态开始,当 poll 被调用时, Future 会尝试去尽可能的推进内部的状态,若它可以被完成时,就会返回 Poll::Ready,其中还会包含最终的输出结果。
若 Future 无法被完成,例如它所等待的资源还没有准备好,此时就会返回 Poll::Pending,该返回值会通知调用者: Future 会在稍后才能完成。
同时可以看到:当一个 Future 由其它 Future 组成时,调用外层 Future 的 poll 函数会同时调用一次内部 Future 的 poll 函数。
执行器( Excecutor )
async fn 返回 Future ,而后者需要通过被不断的 poll 才能往前推进状态,同时该 Future 还能包含其它 Future ,那么问题来了谁来负责调用最外层 Future 的 poll 函数?
回一下之前的内容,为了运行一个异步函数,我们必须使用 tokio::spawn 或 通过 #[tokio::main] 标注的 async fn main 函数。它们有一个非常重要的作用:将最外层 Future 提交给 Tokio 的执行器。该执行器负责调用 poll 函数,然后推动 Future 的执行,最终直至完成。
mini tokio
为了更好理解相关的内容,我们一起来实现一个迷你版本的 Tokio,完整的代码见这里。
先来看一段基础代码:
use std::collections::VecDeque; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::{Duration, Instant}; use futures::task; fn main() { let mut mini_tokio = MiniTokio::new(); mini_tokio.spawn(async { let when = Instant::now() + Duration::from_millis(10); let future = Delay { when }; let out = future.await; assert_eq!(out, "done"); }); mini_tokio.run(); } struct MiniTokio { tasks: VecDeque<Task>, } type Task = Pin<Box<dyn Future<Output = ()> + Send>>; impl MiniTokio { fn new() -> MiniTokio { MiniTokio { tasks: VecDeque::new(), } } /// 生成一个 Future并放入 mini-tokio 实例的任务队列中 fn spawn<F>(&mut self, future: F) where F: Future<Output = ()> + Send + 'static, { self.tasks.push_back(Box::pin(future)); } fn run(&mut self) { let waker = task::noop_waker(); let mut cx = Context::from_waker(&waker); while let Some(mut task) = self.tasks.pop_front() { if task.as_mut().poll(&mut cx).is_pending() { self.tasks.push_back(task); } } } }
以上代码运行了一个 async 语句块 mini_tokio.spawn(async {...}), 还创建了一个 Delay 实例用于等待所需的时间。看上去相当不错,但这个实现有一个 重大缺陷:我们的执行器永远也不会休眠。执行器会持续的循环遍历所有的 Future ,然后不停的 poll 它们,但是事实上,大多数 poll 都是没有用的,因为此时 Future 并没有准备好,因此会继续返回 Poll::Pending ,最终这个循环遍历会让你的 CPU 疲于奔命,真打工人!
鉴于此,我们的 mini-tokio 只应该在 Future 准备好可以进一步运行后,才去 poll 它,例如该 Future 之前阻塞等待的资源已经准备好并可以被使用了,就可以对其进行 poll。再比如,如果一个 Future 任务在阻塞等待从 TCP socket 中读取数据,那我们只想在 socket 中有数据可以读取后才去 poll 它,而不是没事就 poll 着玩。
回到上面的代码中,mini-tokio 只应该当任务的延迟时间到了后,才去 poll 它。 为了实现这个功能,我们需要 通知 -> 运行 机制:当任务可以进一步被推进运行时,它会主动通知执行器,然后执行器再来 poll。
Waker
一切的答案都在 Waker 中,资源可以用它来通知正在等待的任务:该资源已经准备好,可以继续运行了。
再来看下 Future::poll 的定义:
#![allow(unused)] fn main() { fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output>; }
Context 参数中包含有 waker()方法。该方法返回一个绑定到当前任务上的 Waker,然后 Waker 上定义了一个 wake() 方法,用于通知执行器相关的任务可以继续执行。
准确来说,当 Future 阻塞等待的资源已经准备好时(例如 socket 中有了可读取的数据),该资源可以调用 wake() 方法,来通知执行器可以继续调用该 Future 的 poll 函数来推进任务的执行。
发送 wake 通知
现在,为 Delay 添加下 Waker 支持:
#![allow(unused)] fn main() { use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::{Duration, Instant}; use std::thread; struct Delay { when: Instant, } impl Future for Delay { type Output = &'static str; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<&'static str> { if Instant::now() >= self.when { println!("Hello world"); Poll::Ready("done") } else { // 为当前任务克隆一个 waker 的句柄 let waker = cx.waker().clone(); let when = self.when; // 生成一个计时器线程 thread::spawn(move || { let now = Instant::now(); if now < when { thread::sleep(when - now); } waker.wake(); }); Poll::Pending } } } }
此时,计时器用来模拟一个阻塞等待的资源,一旦计时结束(该资源已经准备好),资源会通过 waker.wake() 调用通知执行器我们的任务再次被调度执行了。
当然,现在的实现还较为粗糙,等会我们会来进一步优化,在此之前,先来看看如何监听这个 wake 通知。
当 Future 会返回
Poll::Pending时,一定要确保wake能被正常调用,否则会导致任务永远被挂起,再也不会被执行器poll。忘记在返回
Poll::Pending时调用wake是很多难以发现 bug 的潜在源头!
再回忆下最早实现的 Delay 代码:
#![allow(unused)] fn main() { impl Future for Delay { type Output = &'static str; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<&'static str> { if Instant::now() >= self.when { // 时间到了,Future 可以结束 println!("Hello world"); // Future 执行结束并返回 "done" 字符串 Poll::Ready("done") } else { // 目前先忽略下面这行代码 cx.waker().wake_by_ref(); Poll::Pending } } } }
在返回 Poll::Pending 之前,先调用了 cx.waker().wake_by_ref() ,由于此时我们还没有模拟计时资源,因此这里直接调用了 wake 进行通知,这样做会导致当前的 Future 被立即再次调度执行。
由此可见,这种通知的控制权是在你手里的,甚至可以像上面代码这样,还没准备好资源,就直接进行 wake 通知,但是总归意义不大,而且浪费了 CPU,因为这种 执行 -> 立即通知再调度 -> 执行 的方式会造成一个非常繁忙的循环。
处理 wake 通知
下面,让我们更新 mint-tokio 服务,让它能接收 wake 通知:当 waker.wake() 被调用后,相关联的任务会被放入执行器的队列中,然后等待执行器的调用执行。
为了实现这一点,我们将使用消息通道来排队存储这些被唤醒并等待调度的任务。有一点需要注意,从消息通道接收消息的线程(执行器所在的线程)和发送消息的线程(唤醒任务时所在的线程)可能是不同的,因此消息( Waker )必须要实现 Send和 Sync,才能跨线程使用。
关于
Send和Sync的具体讲解见这里
基于以上理由,我们选择使用来自于 crossbeam 的消息通道,因为标准库中的消息通道不是 Sync 的。在 Cargo.toml 中添加以下依赖:
crossbeam = "0.8"
再来更新下 MiniTokio 结构体:
#![allow(unused)] fn main() { use crossbeam::channel; use std::sync::Arc; struct MiniTokio { scheduled: channel::Receiver<Arc<Task>>, sender: channel::Sender<Arc<Task>>, } struct Task { // 先空着,后面会填充代码 } }
Waker 实现了 Sync 特征,同时还可以被克隆,当 wake 被调用时,任务就会被调度执行。
为了实现上述的目的,我们引入了消息通道,当 waker.wake() 函数被调用时,任务会被发送到该消息通道中:
#![allow(unused)] fn main() { use std::sync::{Arc, Mutex}; struct Task { // `Mutex` 是为了让 `Task` 实现 `Sync` 特征,它能保证同一时间只有一个线程可以访问 `Future`。 // 事实上 `Mutex` 并没有在 Tokio 中被使用,这里我们只是为了简化: Tokio 的真实代码实在太长了 :D future: Mutex<Pin<Box<dyn Future<Output = ()> + Send>>>, executor: channel::Sender<Arc<Task>>, } impl Task { fn schedule(self: &Arc<Self>) { self.executor.send(self.clone()); } } }
接下来,我们需要让 std::task::Waker 能准确的找到所需的调度函数 关联起来,对此标准库中提供了一个底层的 API std::task::RawWakerVTable 可以用于手动的访问 vtable,这种实现提供了最大的灵活性,但是需要大量 unsafe 的代码。
因此我们选择更加高级的实现:由 futures 包提供的 ArcWake 特征,只要简单实现该特征,就可以将我们的 Task 转变成一个 waker。在 Cargo.toml 中添加以下包:
futures = "0.3"
然后为我们的任务 Task 实现 ArcWake:
#![allow(unused)] fn main() { use futures::task::{self, ArcWake}; use std::sync::Arc; impl ArcWake for Task { fn wake_by_ref(arc_self: &Arc<Self>) { arc_self.schedule(); } } }
当之前的计时器线程调用 waker.wake() 时,所在的任务会被推入到消息通道中。因此接下来,我们需要实现接收端的功能,然后 MiniTokio::run() 函数中执行该任务:
#![allow(unused)] fn main() { impl MiniTokio { // 从消息通道中接收任务,然后通过 poll 来执行 fn run(&self) { while let Ok(task) = self.scheduled.recv() { task.poll(); } } /// 初始化一个新的 mini-tokio 实例 fn new() -> MiniTokio { let (sender, scheduled) = channel::unbounded(); MiniTokio { scheduled, sender } } /// 在下面函数中,通过参数传入的 future 被 `Task` 包裹起来,然后会被推入到调度队列中,当 `run` 被调用时,该 future 将被执行 fn spawn<F>(&self, future: F) where F: Future<Output = ()> + Send + 'static, { Task::spawn(future, &self.sender); } } impl Task { fn poll(self: Arc<Self>) { // 基于 Task 实例创建一个 waker, 它使用了之前的 `ArcWake` let waker = task::waker(self.clone()); let mut cx = Context::from_waker(&waker); // 没有其他线程在竞争锁时,我们将获取到目标 future let mut future = self.future.try_lock().unwrap(); // 对 future 进行 poll let _ = future.as_mut().poll(&mut cx); } // 使用给定的 future 来生成新的任务 // // 新的任务会被推到 `sender` 中,接着该消息通道的接收端就可以获取该任务,然后执行 fn spawn<F>(future: F, sender: &channel::Sender<Arc<Task>>) where F: Future<Output = ()> + Send + 'static, { let task = Arc::new(Task { future: Mutex::new(Box::pin(future)), executor: sender.clone(), }); let _ = sender.send(task); } } }
首先,我们实现了 MiniTokio::run() 函数,它会持续从消息通道中接收被唤醒的任务,然后通过 poll 来推动其继续执行。
其次,MiniTokio::new() 和 MiniTokio::spawn() 使用了消息通道而不是一个 VecDeque 。当新任务生成后,这些任务中会携带上消息通道的发送端,当任务中的资源准备就绪时,会使用该发送端将该任务放入消息通道的队列中,等待执行器 poll。
Task::poll() 函数使用 futures 包提供的 ArcWake 创建了一个 waker,后者可以用来创建 task::Context,最终该 Context 会被传给执行器调用的 poll 函数。
注意,Task::poll 和执行器调用的 poll 是完全不同的,大家别搞混了
一些遗留问题
至此,我们的程序已经差不多完成,还剩几个遗留问题需要解决下。
在异步函数中生成异步任务
之前实现 Delay Future 时,我们提到有几个问题需要解决。Rust 的异步模型允许一个 Future 在执行过程中可以跨任务迁移:
use futures::future::poll_fn; use std::future::Future; use std::pin::Pin; #[tokio::main] async fn main() { let when = Instant::now() + Duration::from_millis(10); let mut delay = Some(Delay { when }); poll_fn(move |cx| { let mut delay = delay.take().unwrap(); let res = Pin::new(&mut delay).poll(cx); assert!(res.is_pending()); tokio::spawn(async move { delay.await; }); Poll::Ready(()) }).await; }
首先,poll_fn 函数使用闭包创建了一个 Future,其次,上面代码还创建一个 Delay 实例,然后在闭包中,对其进行了一次 poll ,接着再将该 Delay 实例发送到一个新的任务,在此任务中使用 .await 进行了执行。
在例子中,Delay:poll 被调用了不止一次,且使用了不同的 Waker 实例,在这种场景下,你必须确保调用最近一次 poll 函数中的 Waker 参数中的wake方法。也就是调用最内层 poll 函数参数( Waker )上的 wake 方法。
当实现一个 Future 时,很关键的一点就是要假设每次 poll 调用都会应用到一个不同的 Waker 实例上。因此 poll 函数必须要使用一个新的 waker 去更新替代之前的 waker。
我们之前的 Delay 实现中,会在每一次 poll 调用时都生成一个新的线程。这么做问题不大,但是当 poll 调用较多时会出现明显的性能问题!一个解决方法就是记录你是否已经生成了一个线程,然后只有在没有生成时才去创建一个新的线程。但是一旦这么做,就必须确保线程的 Waker 在后续 poll 调用中被正确更新,否则你无法唤醒最近的 Waker !
这一段大家可能会看得云里雾里的,没办法,原文就饶来绕去,好在终于可以看代码了。。我们可以通过代码来解决疑惑:
#![allow(unused)] fn main() { use std::future::Future; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll, Waker}; use std::thread; use std::time::{Duration, Instant}; struct Delay { when: Instant, // 用于说明是否已经生成一个线程 // Some 代表已经生成, None 代表还没有 waker: Option<Arc<Mutex<Waker>>>, } impl Future for Delay { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { // 若这是 Future 第一次被调用,那么需要先生成一个计时器线程。 // 若不是第一次调用(该线程已在运行),那要确保已存储的 `Waker` 跟当前任务的 `waker` 匹配 if let Some(waker) = &self.waker { let mut waker = waker.lock().unwrap(); // 检查之前存储的 `waker` 是否跟当前任务的 `waker` 相匹配. // 这是必要的,原因是 `Delay Future` 的实例可能会在两次 `poll` 之间被转移到另一个任务中,然后 // 存储的 waker 被该任务进行了更新。 // 这种情况一旦发生,`Context` 包含的 `waker` 将不同于存储的 `waker`。 // 因此我们必须对存储的 `waker` 进行更新 if !waker.will_wake(cx.waker()) { *waker = cx.waker().clone(); } } else { let when = self.when; let waker = Arc::new(Mutex::new(cx.waker().clone())); self.waker = Some(waker.clone()); // 第一次调用 `poll`,生成计时器线程 thread::spawn(move || { let now = Instant::now(); if now < when { thread::sleep(when - now); } // 计时结束,通过调用 `waker` 来通知执行器 let waker = waker.lock().unwrap(); waker.wake_by_ref(); }); } // 一旦 waker 被存储且计时器线程已经开始,我们就需要检查 `delay` 是否已经完成 // 若计时已完成,则当前 Future 就可以完成并返回 `Poll::Ready` if Instant::now() >= self.when { Poll::Ready(()) } else { // 计时尚未结束,Future 还未完成,因此返回 `Poll::Pending`. // // `Future` 特征要求当 `Pending` 被返回时,那我们要确保当资源准备好时,必须调用 `waker` 以通 // 知执行器。 在我们的例子中,会通过生成的计时线程来保证 // // 如果忘记调用 waker, 那等待我们的将是深渊:该任务将被永远的挂起,无法再执行 Poll::Pending } } } }
这着实有些复杂(原文。。),但是简单来看就是:在每次 poll 调用时,都会检查 Context 中提供的 waker 和我们之前记录的 waker 是否匹配。若匹配,就什么都不用做,若不匹配,那之前存储的就必须进行更新。
Notify
我们之前证明了如何用手动编写的 waker 来实现 Delay Future。 Waker 是 Rust 异步编程的基石,因此绝大多数时候,我们并不需要直接去使用它。例如,在 Delay 的例子中, 可以使用 tokio::sync::Notify 去实现。
该 Notify 提供了一个基础的任务通知机制,它会处理这些 waker 的细节,包括确保两次 waker 的匹配:
#![allow(unused)] fn main() { use tokio::sync::Notify; use std::sync::Arc; use std::time::{Duration, Instant}; use std::thread; async fn delay(dur: Duration) { let when = Instant::now() + dur; let notify = Arc::new(Notify::new()); let notify2 = notify.clone(); thread::spawn(move || { let now = Instant::now(); if now < when { thread::sleep(when - now); } notify2.notify_one(); }); notify.notified().await; } }
当使用 Notify 后,我们就可以轻松的实现如上的 delay 函数。
总结
在看完这么长的文章后,我们来总结下,否则大家可能还会遗忘:
- 在 Rust 中,
async是惰性的,直到执行器poll它们时,才会开始执行 Waker是Future被执行的关键,它可以链接起Future任务和执行器- 当资源没有准备时,会返回一个
Poll::Pending - 当资源准备好时,会通过
waker.wake发出通知 - 执行器会收到通知,然后调度该任务继续执行,此时由于资源已经准备好,因此任务可以顺利往前推进了
select!
在实际使用时,一个重要的场景就是同时等待多个异步操作的结果,并且对其结果进行进一步处理,在本章节,我们来看看,强大的 select! 是如何帮助咱们更好的控制多个异步操作并发执行的。
tokio::select!
select! 允许同时等待多个计算操作,然后当其中一个操作完成时就退出等待:
use tokio::sync::oneshot; #[tokio::main] async fn main() { let (tx1, rx1) = oneshot::channel(); let (tx2, rx2) = oneshot::channel(); tokio::spawn(async { let _ = tx1.send("one"); }); tokio::spawn(async { let _ = tx2.send("two"); }); tokio::select! { val = rx1 => { println!("rx1 completed first with {:?}", val); } val = rx2 => { println!("rx2 completed first with {:?}", val); } } // 任何一个 select 分支结束后,都会继续执行接下来的代码 }
这里用到了两个 oneshot 消息通道,虽然两个操作的创建在代码上有先后顺序,但在实际执行时却不这样。因此, select 在从两个通道阻塞等待接收消息时,rx1 和 rx2 都有可能被先打印出来。
需要注意,任何一个 select 分支完成后,都会继续执行后面的代码,没被执行的分支会被丢弃( dropped )。
取消
对于 Async Rust 来说,释放( drop )掉一个 Future 就意味着取消任务。从上一章节可以得知, async 操作会返回一个 Future,而后者是惰性的,直到被 poll 调用时,才会被执行。一旦 Future 被释放,那操作将无法继续,因为所有相关的状态都被释放。
对于 Tokio 的 oneshot 的接收端来说,它在被释放时会发送一个关闭通知到发送端,因此发送端可以通过释放任务的方式来终止正在执行的任务。
use tokio::sync::oneshot; async fn some_operation() -> String { // 在这里执行一些操作... } #[tokio::main] async fn main() { let (mut tx1, rx1) = oneshot::channel(); let (tx2, rx2) = oneshot::channel(); tokio::spawn(async { // 等待 `some_operation` 的完成 // 或者处理 `oneshot` 的关闭通知 tokio::select! { val = some_operation() => { let _ = tx1.send(val); } _ = tx1.closed() => { // 收到了发送端发来的关闭信号 // `select` 即将结束,此时,正在进行的 `some_operation()` 任务会被取消,任务自动完成, // tx1 被释放 } } }); tokio::spawn(async { let _ = tx2.send("two"); }); tokio::select! { val = rx1 => { println!("rx1 completed first with {:?}", val); } val = rx2 => { println!("rx2 completed first with {:?}", val); } } }
上面代码的重点就在于 tx1.closed 所在的分支,一旦发送端被关闭,那该分支就会被执行,然后 select 会退出,并清理掉还没执行的第一个分支 val = some_operation() ,这其中 some_operation 返回的 Future 也会被清理,根据之前的内容,Future 被清理那相应的任务会立即取消,因此 some_operation 会被取消,不再执行。
Future 的实现
为了更好的理解 select 的工作原理,我们来看看如果使用 Future 该如何实现。当然,这里是一个简化版本,在实际中,select! 会包含一些额外的功能,例如一开始会随机选择一个分支进行 poll。
use tokio::sync::oneshot; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; struct MySelect { rx1: oneshot::Receiver<&'static str>, rx2: oneshot::Receiver<&'static str>, } impl Future for MySelect { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { if let Poll::Ready(val) = Pin::new(&mut self.rx1).poll(cx) { println!("rx1 completed first with {:?}", val); return Poll::Ready(()); } if let Poll::Ready(val) = Pin::new(&mut self.rx2).poll(cx) { println!("rx2 completed first with {:?}", val); return Poll::Ready(()); } Poll::Pending } } #[tokio::main] async fn main() { let (tx1, rx1) = oneshot::channel(); let (tx2, rx2) = oneshot::channel(); // 使用 tx1 和 tx2 MySelect { rx1, rx2, }.await; }
MySelect 包含了两个分支中的 Future,当它被 poll 时,第一个分支会先执行。如果执行完成,那取出的值会被使用,然后 MySelect 也随之结束。而另一个分支对应的 Future 会被释放掉,对应的操作也会被取消。
还记得上一章节中很重要的一段话吗?
当一个
Future返回Poll::Pending时,它必须确保会在某一个时刻通过Waker来唤醒,不然该Future将永远地被挂起
但是仔细观察我们之前的代码,里面并没有任何的 wake 调用!事实上,这是因为参数 cx 被传入了内层的 poll 调用。 只要内部的 Future 实现了唤醒并且返回了 Poll::Pending,那 MySelect 也等于实现了唤醒!
语法
目前来说,select! 最多可以支持 64 个分支,每个分支形式如下:
#![allow(unused)] fn main() { <模式> = <async 表达式> => <结果处理>, }
当 select 宏开始执行后,所有的分支会开始并发的执行。当任何一个表达式完成时,会将结果跟模式进行匹配。若匹配成功,则剩下的表达式会被释放。
最常用的模式就是用变量名去匹配表达式返回的值,然后该变量就可以在结果处理环节使用。
如果当前的模式不能匹配,剩余的 async 表达式将继续并发的执行,直到下一个完成。
由于 select! 使用的是一个 async 表达式,因此我们可以定义一些更复杂的计算。
例如从在分支中进行 TCP 连接:
use tokio::net::TcpStream; use tokio::sync::oneshot; #[tokio::main] async fn main() { let (tx, rx) = oneshot::channel(); // 生成一个任务,用于向 oneshot 发送一条消息 tokio::spawn(async move { tx.send("done").unwrap(); }); tokio::select! { socket = TcpStream::connect("localhost:3465") => { println!("Socket connected {:?}", socket); } msg = rx => { println!("received message first {:?}", msg); } } }
再比如,在分支中进行 TCP 监听:
use tokio::net::TcpListener; use tokio::sync::oneshot; use std::io; #[tokio::main] async fn main() -> io::Result<()> { let (tx, rx) = oneshot::channel(); tokio::spawn(async move { tx.send(()).unwrap(); }); let mut listener = TcpListener::bind("localhost:3465").await?; tokio::select! { _ = async { loop { let (socket, _) = listener.accept().await?; tokio::spawn(async move { process(socket) }); } // 给予 Rust 类型暗示 Ok::<_, io::Error>(()) } => {} _ = rx => { println!("terminating accept loop"); } } Ok(()) }
分支中接收连接的循环会一直运行,直到遇到错误才停止,或者当 rx 中有值时,也会停止。 _ 表示我们并不关心这个值,这样使用唯一的目的就是为了结束第一分支中的循环。
返回值
select! 还能返回一个值:
async fn computation1() -> String { // .. 计算 } async fn computation2() -> String { // .. 计算 } #[tokio::main] async fn main() { let out = tokio::select! { res1 = computation1() => res1, res2 = computation2() => res2, }; println!("Got = {}", out); }
需要注意的是,此时 select! 的所有分支必须返回一样的类型,否则编译器会报错!
错误传播
在 Rust 中使用 ? 可以对错误进行传播,但是在 select! 中,? 如何工作取决于它是在分支中的 async 表达式使用还是在结果处理的代码中使用:
- 在分支中
async表达式使用会将该表达式的结果变成一个Result - 在结果处理中使用,会将错误直接传播到
select!之外
use tokio::net::TcpListener; use tokio::sync::oneshot; use std::io; #[tokio::main] async fn main() -> io::Result<()> { // [设置 `rx` oneshot 消息通道] let listener = TcpListener::bind("localhost:3465").await?; tokio::select! { res = async { loop { let (socket, _) = listener.accept().await?; tokio::spawn(async move { process(socket) }); } Ok::<_, io::Error>(()) } => { res?; } _ = rx => { println!("terminating accept loop"); } } Ok(()) }
listener.accept().await? 是分支表达式中的 ?,因此它会将表达式的返回值变成 Result 类型,然后赋予给 res 变量。
与之不同的是,结果处理中的 res?; 会让 main 函数直接结束并返回一个 Result,可以看出,这里 ? 的用法跟我们平时的用法并无区别。
模式匹配
既然是模式匹配,我们需要再来回忆下 select! 的分支语法形式:
#![allow(unused)] fn main() { <模式> = <async 表达式> => <结果处理>, }
迄今为止,我们只用了变量绑定的模式,事实上,任何 Rust 模式都可以在此处使用。
use tokio::sync::mpsc; #[tokio::main] async fn main() { let (mut tx1, mut rx1) = mpsc::channel(128); let (mut tx2, mut rx2) = mpsc::channel(128); tokio::spawn(async move { // 用 tx1 和 tx2 干一些不为人知的事 }); tokio::select! { Some(v) = rx1.recv() => { println!("Got {:?} from rx1", v); } Some(v) = rx2.recv() => { println!("Got {:?} from rx2", v); } else => { println!("Both channels closed"); } } }
上面代码中,rx 通道关闭后,recv() 方法会返回一个 None,可以看到没有任何模式能够匹配这个 None,那为何不会报错?秘密就在于 else 上:当使用模式去匹配分支时,若之前的所有分支都无法被匹配,那 else 分支将被执行。
借用
当在 Tokio 中生成( spawn )任务时,其 async 语句块必须拥有其中数据的所有权。而 select! 并没有这个限制,它的每个分支表达式可以直接借用数据,然后进行并发操作。只要遵循 Rust 的借用规则,多个分支表达式可以不可变的借用同一个数据,或者在一个表达式可变的借用某个数据。
来看个例子,在这里我们同时向两个 TCP 目标发送同样的数据:
#![allow(unused)] fn main() { use tokio::io::AsyncWriteExt; use tokio::net::TcpStream; use std::io; use std::net::SocketAddr; async fn race( data: &[u8], addr1: SocketAddr, addr2: SocketAddr ) -> io::Result<()> { tokio::select! { Ok(_) = async { let mut socket = TcpStream::connect(addr1).await?; socket.write_all(data).await?; Ok::<_, io::Error>(()) } => {} Ok(_) = async { let mut socket = TcpStream::connect(addr2).await?; socket.write_all(data).await?; Ok::<_, io::Error>(()) } => {} else => {} }; Ok(()) } }
这里其实有一个很有趣的题外话,由于 TCP 连接过程是在模式中发生的,因此当某一个连接过程失败后,它通过 ? 返回的 Err 类型并无法匹配 Ok,因此另一个分支会继续被执行,继续连接。
如果你把连接过程放在了结果处理中,那连接失败会直接从 race 函数中返回,而不是继续执行另一个分支中的连接!
还有一个非常重要的点,借用规则在分支表达式和结果处理中存在很大的不同。例如上面代码中,我们在两个分支表达式中分别对 data 做了不可变借用,这当然 ok,但是若是两次可变借用,那编译器会立即进行报错。但是转折来了:当在结果处理中进行两次可变借用时,却不会报错,大家可以思考下为什么,提示下:思考下分支在执行完成后会发生什么?
use tokio::sync::oneshot; #[tokio::main] async fn main() { let (tx1, rx1) = oneshot::channel(); let (tx2, rx2) = oneshot::channel(); let mut out = String::new(); tokio::spawn(async move { }); tokio::select! { _ = rx1 => { out.push_str("rx1 completed"); } _ = rx2 => { out.push_str("rx2 completed"); } } println!("{}", out); }
例如以上代码,就在两个分支的结果处理中分别进行了可变借用,并不会报错。原因就在于:select!会保证只有一个分支的结果处理会被运行,然后在运行结束后,另一个分支会被直接丢弃。
循环
来看看该如何在循环中使用 select!,顺便说一句,跟循环一起使用是最常见的使用方式。
use tokio::sync::mpsc; #[tokio::main] async fn main() { let (tx1, mut rx1) = mpsc::channel(128); let (tx2, mut rx2) = mpsc::channel(128); let (tx3, mut rx3) = mpsc::channel(128); loop { let msg = tokio::select! { Some(msg) = rx1.recv() => msg, Some(msg) = rx2.recv() => msg, Some(msg) = rx3.recv() => msg, else => { break } }; println!("Got {}", msg); } println!("All channels have been closed."); }
在循环中使用 select! 最大的不同就是,当某一个分支执行完成后,select! 会继续循环等待并执行下一个分支,直到所有分支最终都完成,最终匹配到 else 分支,然后通过 break 跳出循环。
老生常谈的一句话:select! 中哪个分支先被执行是无法确定的,因此不要依赖于分支执行的顺序!想象一下,在异步编程场景,若 select! 按照分支的顺序来执行会如何:若 rx1 中总是有数据,那每次循环都只会去处理第一个分支,后面两个分支永远不会被执行。
恢复之前的异步操作
async fn action() { // 一些异步逻辑 } #[tokio::main] async fn main() { let (mut tx, mut rx) = tokio::sync::mpsc::channel(128); let operation = action(); tokio::pin!(operation); loop { tokio::select! { _ = &mut operation => break, Some(v) = rx.recv() => { if v % 2 == 0 { break; } } } } }
在上面代码中,我们没有直接在 select! 分支中调用 action() ,而是在 loop 循环外面先将 action() 赋值给 operation,因此 operation 是一个 Future。
重点来了,在 select! 循环中,我们使用了一个奇怪的语法 &mut operation,大家想象一下,如果不加 &mut 会如何?答案是,每一次循环调用的都是一次全新的 action()调用,但是当加了 &mut operatoion 后,每一次循环调用就变成了对同一次 action() 的调用。也就是我们实现了在每次循环中恢复了之前的异步操作!
select! 的另一个分支从消息通道收取消息,一旦收到值是偶数,就跳出循环,否则就继续循环。
还有一个就是我们使用了 tokio::pin!,具体的细节这里先不介绍,值得注意的点是:如果要在一个引用上使用 .await,那么引用的值就必须是不能移动的或者实现了 Unpin,关于 Pin 和 Unpin 可以参见这里。
一旦移除 tokio::pin! 所在行的代码,然后试图编译,就会获得以下错误:
error[E0599]: no method named `poll` found for struct
`std::pin::Pin<&mut &mut impl std::future::Future>`
in the current scope
--> src/main.rs:16:9
|
16 | / tokio::select! {
17 | | _ = &mut operation => break,
18 | | Some(v) = rx.recv() => {
19 | | if v % 2 == 0 {
... |
22 | | }
23 | | }
| |_________^ method not found in
| `std::pin::Pin<&mut &mut impl std::future::Future>`
|
= note: the method `poll` exists but the following trait bounds
were not satisfied:
`impl std::future::Future: std::marker::Unpin`
which is required by
`&mut impl std::future::Future: std::future::Future`
虽然我们已经学了很多关于 Future 的知识,但是这个错误依然不太好理解。但是它不难解决:当你试图在一个引用上调用 .await 然后遇到了 Future 未实现 这种错误时,往往只需要将对应的 Future 进行固定即可: tokio::pin!(operation);。
修改一个分支
下面一起来看一个稍微复杂一些的 loop 循环,首先,我们拥有:
- 一个消息通道可以传递
i32类型的值 - 定义在
i32值上的一个异步操作
想要实现的逻辑是:
- 在消息通道中等待一个偶数出现
- 使用该偶数作为输入来启动一个异步操作
- 等待异步操作完成,与此同时监听消息通道以获取更多的偶数
- 若在异步操作完成前一个新的偶数到来了,终止当前的异步操作,然后接着使用新的偶数开始异步操作
async fn action(input: Option<i32>) -> Option<String> { // 若 input(输入)是None,则返回 None // 事实上也可以这么写: `let i = input?;` let i = match input { Some(input) => input, None => return None, }; // 这里定义一些逻辑 } #[tokio::main] async fn main() { let (mut tx, mut rx) = tokio::sync::mpsc::channel(128); let mut done = false; let operation = action(None); tokio::pin!(operation); tokio::spawn(async move { let _ = tx.send(1).await; let _ = tx.send(3).await; let _ = tx.send(2).await; }); loop { tokio::select! { res = &mut operation, if !done => { done = true; if let Some(v) = res { println!("GOT = {}", v); return; } } Some(v) = rx.recv() => { if v % 2 == 0 { // `.set` 是 `Pin` 上定义的方法 operation.set(action(Some(v))); done = false; } } } } }
当第一次循环开始时, 第一个分支会立即完成,因为 operation 的参数是 None。当第一个分支执行完成时,done 会变成 true,此时第一个分支的条件将无法被满足,开始执行第二个分支。
当第二个分支收到一个偶数时,done 会被修改为 false,且 operation 被设置了值。 此后再一次循环时,第一个分支会被执行,且 operation 返回一个 Some(2),因此会触发 return ,最终结束循环并返回。
这段代码引入了一个新的语法: if !done,在解释之前,先看看去掉后会如何:
thread 'main' panicked at '`async fn` resumed after completion', src/main.rs:1:55
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
'`async fn` resumed after completion' 错误的含义是:async fn 异步函数在完成后,依然被恢复了(继续使用)。
回到例子中来,这个错误是由于 operation 在它已经调用完成后依然被使用。通常来说,当使用 .await 后,调用 .await 的值会被消耗掉,因此并不存在这个问题。但是在这例子中,我们在引用上调用 .await,因此之后该引用依然可以被使用。
为了避免这个问题,需要在第一个分支的 operation 完成后禁止再使用该分支。这里的 done 的引入就很好的解决了问题。对于 select! 来说 if !done 的语法被称为预条件( precondition ),该条件会在分支被 .await 执行前进行检查。
那大家肯定有疑问了,既然 operation 不能再被调用了,我们该如何在有偶数值时,再回到第一个分支对其进行调用呢?答案就是 operation.set(action(Some(v)));,该操作会重新使用新的参数设置 operation。
spawn 和 select! 的一些不同
学到现在,相信大家对于 tokio::spawn 和 select! 已经非常熟悉,它们的共同点就是都可以并发的运行异步操作。
然而它们使用的策略大相径庭。
tokio::spawn 函数会启动新的任务来运行一个异步操作,每个任务都是一个独立的对象可以单独被 Tokio 调度运行,因此两个不同的任务的调度都是独立进行的,甚至于它们可能会运行在两个不同的操作系统线程上。鉴于此,生成的任务和生成的线程有一个相同的限制:不允许对外部环境中的值进行借用。
而 select! 宏就不一样了,它在同一个任务中并发运行所有的分支。正是因为这样,在同一个任务中,这些分支无法被同时运行。 select! 宏在单个任务中实现了多路复用的功能。
Stream
大家有没有想过, Rust 中的迭代器在迭代时能否异步进行?若不可以,是不是有相应的解决方案?
以上的问题其实很重要,因为在实际场景中,迭代一个集合,然后异步的去执行是很常见的需求,好在 Tokio 为我们提供了 stream,我们可以在异步函数中对其进行迭代,甚至和迭代器 Iterator 一样,stream 还能使用适配器,例如 map ! Tokio 在 StreamExt 特征上定义了常用的适配器。
要使用 stream ,目前还需要手动引入对应的包:
#![allow(unused)] fn main() { tokio-stream = "0.1" }
stream 没有放在
tokio包的原因在于标准库中的Stream特征还没有稳定,一旦稳定后,stream将移动到tokio中来
迭代
目前, Rust 语言还不支持异步的 for 循环,因此我们需要 while let 循环和 StreamExt::next() 一起使用来实现迭代的目的:
use tokio_stream::StreamExt; #[tokio::main] async fn main() { let mut stream = tokio_stream::iter(&[1, 2, 3]); while let Some(v) = stream.next().await { println!("GOT = {:?}", v); } }
和迭代器 Iterator 类似,next() 方法返回一个 Option<T>,其中 T 是从 stream 中获取的值的类型。若收到 None 则意味着 stream 迭代已经结束。
mini-redis 广播
下面我们来实现一个复杂一些的 mini-redis 客户端,完整代码见这里。
在开始之前,首先启动一下完整的 mini-redis 服务器端:
$ mini-redis-server
use tokio_stream::StreamExt; use mini_redis::client; async fn publish() -> mini_redis::Result<()> { let mut client = client::connect("127.0.0.1:6379").await?; // 发布一些数据 client.publish("numbers", "1".into()).await?; client.publish("numbers", "two".into()).await?; client.publish("numbers", "3".into()).await?; client.publish("numbers", "four".into()).await?; client.publish("numbers", "five".into()).await?; client.publish("numbers", "6".into()).await?; Ok(()) } async fn subscribe() -> mini_redis::Result<()> { let client = client::connect("127.0.0.1:6379").await?; let subscriber = client.subscribe(vec!["numbers".to_string()]).await?; let messages = subscriber.into_stream(); tokio::pin!(messages); while let Some(msg) = messages.next().await { println!("got = {:?}", msg); } Ok(()) } #[tokio::main] async fn main() -> mini_redis::Result<()> { tokio::spawn(async { publish().await }); subscribe().await?; println!("DONE"); Ok(()) }
上面生成了一个异步任务专门用于发布消息到 min-redis 服务器端的 numbers 消息通道中。然后,在 main 中,我们订阅了 numbers 消息通道,并且打印从中接收到的消息。
还有几点值得注意的:
into_stream会将Subscriber变成一个stream- 在
stream上调用next方法要求该stream被固定住(pinned),因此需要调用tokio::pin!
关于 Pin 的详细解读,可以阅读这篇文章
大家可以去掉 pin! 的调用,然后观察下报错,若以后你遇到这种错误,可以尝试使用下 pin!。
此时,可以运行下我们的客户端代码看看效果(别忘了先启动前面提到的 mini-redis 服务端):
got = Ok(Message { channel: "numbers", content: b"1" })
got = Ok(Message { channel: "numbers", content: b"two" })
got = Ok(Message { channel: "numbers", content: b"3" })
got = Ok(Message { channel: "numbers", content: b"four" })
got = Ok(Message { channel: "numbers", content: b"five" })
got = Ok(Message { channel: "numbers", content: b"6" })
在了解了 stream 的基本用法后,我们再来看看如何使用适配器来扩展它。
适配器
在前面章节中,我们了解了迭代器有两种适配器:
- 迭代器适配器,会将一个迭代器转变成另一个迭代器,例如
map,filter等 - 消费者适配器,会消费掉一个迭代器,最终生成一个值,例如
collect可以将迭代器收集成一个集合
与迭代器类似,stream 也有适配器,例如一个 stream 适配器可以将一个 stream 转变成另一个 stream ,例如 map、take 和 filter。
在之前的客户端中,subscribe 订阅一直持续下去,直到程序被关闭。现在,让我们来升级下,让它在收到三条消息后就停止迭代,最终结束。
#![allow(unused)] fn main() { let messages = subscriber .into_stream() .take(3); }
这里关键就在于 take 适配器,它会限制 stream 只能生成最多 n 条消息。运行下看看结果:
got = Ok(Message { channel: "numbers", content: b"1" })
got = Ok(Message { channel: "numbers", content: b"two" })
got = Ok(Message { channel: "numbers", content: b"3" })
程序终于可以正常结束了。现在,让我们过滤 stream 中的消息,只保留数字类型的值:
#![allow(unused)] fn main() { let messages = subscriber .into_stream() .filter(|msg| match msg { Ok(msg) if msg.content.len() == 1 => true, _ => false, }) .take(3); }
运行后输出:
got = Ok(Message { channel: "numbers", content: b"1" })
got = Ok(Message { channel: "numbers", content: b"3" })
got = Ok(Message { channel: "numbers", content: b"6" })
需要注意的是,适配器的顺序非常重要,.filter(...).take(3) 和 .take(3).filter(...) 的结果可能大相径庭,大家可以自己尝试下。
现在,还有一件事要做,咱们的消息被不太好看的 Ok(...) 所包裹,现在通过 map 适配器来简化下:
#![allow(unused)] fn main() { let messages = subscriber .into_stream() .filter(|msg| match msg { Ok(msg) if msg.content.len() == 1 => true, _ => false, }) .map(|msg| msg.unwrap().content) .take(3); }
注意到 msg.unwrap 了吗?大家可能会以为我们是出于示例的目的才这么用,实际上并不是,由于 filter 的先执行, map 中的 msg 只能是 Ok(...),因此 unwrap 非常安全。
got = b"1"
got = b"3"
got = b"6"
还有一点可以改进的地方:当 filter 和 map 一起使用时,你往往可以用一个统一的方法来实现 filter_map。
#![allow(unused)] fn main() { let messages = subscriber .into_stream() .filter_map(|msg| match msg { Ok(msg) if msg.content.len() == 1 => Some(msg.content), _ => None, }) .take(3); }
想要学习更多的适配器,可以看看 StreamExt 特征。
实现 Stream 特征
如果大家还没忘记 Future 特征,那 Stream 特征相信你也会很快记住,因为它们非常类似:
#![allow(unused)] fn main() { use std::pin::Pin; use std::task::{Context, Poll}; pub trait Stream { type Item; fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll<Option<Self::Item>>; fn size_hint(&self) -> (usize, Option<usize>) { (0, None) } } }
Stream::poll_next() 函数跟 Future::poll 很相似,区别就是前者为了从 stream 收到多个值需要重复的进行调用。 就像在 深入async 章节提到的那样,当一个 stream 没有做好返回一个值的准备时,它将返回一个 Poll::Pending ,同时将任务的 waker 进行注册。一旦 stream 准备好后, waker 将被调用。
通常来说,如果想要手动实现一个 Stream,需要组合 Future 和其它 Stream。下面,还记得在深入async 中构建的 Delay Future 吗?现在让我们来更进一步,将它转换成一个 stream,每 10 毫秒生成一个值,总共生成 3 次:
#![allow(unused)] fn main() { use tokio_stream::Stream; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; struct Interval { rem: usize, delay: Delay, } impl Stream for Interval { type Item = (); fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<()>> { if self.rem == 0 { // 去除计时器实现 return Poll::Ready(None); } match Pin::new(&mut self.delay).poll(cx) { Poll::Ready(_) => { let when = self.delay.when + Duration::from_millis(10); self.delay = Delay { when }; self.rem -= 1; Poll::Ready(Some(())) } Poll::Pending => Poll::Pending, } } } }
async-stream
手动实现 Stream 特征实际上是相当麻烦的事,不幸地是,Rust 语言的 async/await 语法目前还不能用于定义 stream,虽然相关的工作已经在进行中。
作为替代方案,async-stream 包提供了一个 stream! 宏,它可以将一个输入转换成 stream,使用这个包,上面的代码可以这样实现:
#![allow(unused)] fn main() { use async_stream::stream; use std::time::{Duration, Instant}; stream! { let mut when = Instant::now(); for _ in 0..3 { let delay = Delay { when }; delay.await; yield (); when += Duration::from_millis(10); } } }
嗯,看上去还是相当不错的,代码可读性大幅提升!
是不是发现了一个关键字 yield ,他是用来配合生成器使用的。详见原文
优雅的关闭
如果你的服务是一个小说阅读网站,那大概率用不到优雅关闭的,简单粗暴的关闭服务器,然后用户再次请求时获取一个错误就是了。但如果是一个 web 服务或数据库服务呢?当前的连接很可能在做着重要的事情,一旦关闭会导致数据的丢失甚至错误,此时,我们就需要优雅的关闭(graceful shutdown)了。
要让一个异步应用优雅的关闭往往需要做到 3 点:
- 找出合适的关闭时机
- 通知程序的每一个子部分开始关闭
- 在主线程等待各个部分的关闭结果
在本文的下面部分,我们一起来看看该如何做到这三点。如果想要进一步了解在真实项目中该如何使用,大家可以看看 mini-redis 的完整代码实现,特别是 src/server.rs 和 src/shutdown.rs。
找出合适的关闭时机
一般来说,何时关闭是取决于应用自身的,但是一个常用的关闭准则就是当应用收到来自于操作系统的关闭信号时。例如通过 ctrl + c 来关闭正在运行的命令行程序。
为了检测来自操作系统的关闭信号,Tokio 提供了一个 tokio::signal::ctrl_c 函数,它将一直睡眠直到收到对应的信号:
use tokio::signal; #[tokio::main] async fn main() { // ... spawn application as separate task ... // 在一个单独的任务中处理应用逻辑 match signal::ctrl_c().await { Ok(()) => {}, Err(err) => { eprintln!("Unable to listen for shutdown signal: {}", err); }, } // 发送关闭信号给应用所在的任务,然后等待 }
通知程序的每一个部分开始关闭
大家看到这个标题,不知道会想到用什么技术来解决问题,反正我首先想到的是,真的很像广播哎。。
事实上也是如此,最常见的通知程序各个部分关闭的方式就是使用一个广播消息通道。关于如何实现,其实也不复杂:应用中的每个任务都持有一个广播消息通道的接收端,当消息被广播到该通道时,每个任务都可以收到该消息,并关闭自己:
#![allow(unused)] fn main() { let next_frame = tokio::select! { res = self.connection.read_frame() => res?, _ = self.shutdown.recv() => { // 当收到关闭信号后,直接从 `select!` 返回,此时 `select!` 中的另一个分支会自动释放,其中的任务也会结束 return Ok(()); } }; }
在 mini-redis 中,当收到关闭消息时,任务会立即结束,但在实际项目中,这种方式可能会过于理想,例如当我们向文件或数据库写入数据时,立刻终止任务可能会导致一些无法预料的错误,因此,在结束前做一些收尾工作会是非常好的选择。
除此之外,还有两点值得注意:
- 将广播消息通道作为结构体的一个字段是相当不错的选择, 例如这个例子
- 还可以使用
watch channel实现同样的效果,与之前的方式相比,这两种方法并没有太大的区别
等待各个部分的结束
在之前章节,我们讲到过一个 mpsc 消息通道有一个重要特性:当所有发送端都 drop 时,消息通道会自动关闭,此时继续接收消息就会报错。
大家发现没?这个特性特别适合优雅关闭的场景:主线程持有消息通道的接收端,然后每个代码部分拿走一个发送端,当该部分结束时,就 drop 掉发送端,因此所有发送端被 drop 也就意味着所有的部分都已关闭,此时主线程的接收端就会收到错误,进而结束。
use tokio::sync::mpsc::{channel, Sender}; use tokio::time::{sleep, Duration}; #[tokio::main] async fn main() { let (send, mut recv) = channel(1); for i in 0..10 { tokio::spawn(some_operation(i, send.clone())); } // 等待各个任务的完成 // // 我们需要 drop 自己的发送端,因为等下的 `recv()` 调用会阻塞, 如果不 `drop` ,那发送端就无法被全部关闭 // `recv` 也将永远无法结束,这将陷入一个类似死锁的困境 drop(send); // 当所有发送端都超出作用域被 `drop` 时 (当前的发送端并不是因为超出作用域被 `drop` 而是手动 `drop` 的) // `recv` 调用会返回一个错误 let _ = recv.recv().await; } async fn some_operation(i: u64, _sender: Sender<()>) { sleep(Duration::from_millis(100 * i)).await; println!("Task {} shutting down.", i); // 发送端超出作用域,然后被 `drop` }
关于忘记 drop 本身持有的发送端进而导致 bug 的问题,大家可以看看这篇文章。
异步跟同步共存
一些异步程序例如 tokio 指南 章节中的绝大多数例子,它们整个程序都是异步的,包括程序入口 main 函数:
#[tokio::main] async fn main() { println!("Hello world"); }
在一些场景中,你可能只想在异步程序中运行一小部分同步代码,这种需求可以考虑下 spawn_blocking。
但是在很多场景中,我们只想让程序的某一个部分成为异步的,也许是因为同步代码更好实现,又或许是同步代码可读性、兼容性都更好。例如一个 GUI 应用可能想要让 UI 相关的代码在主线程中,然后通过另一个线程使用 tokio 的运行时来处理一些异步任务。
因此本章节的目标很纯粹:如何在同步代码中使用一小部分异步代码。
#[tokio::main] 的展开
在 Rust 中, main 函数不能是异步的,有同学肯定不愿意了,我们在之前章节..不对,就在开头,你还用到了 async fn main 的声明方式,怎么就不能异步了呢?
其实,#[tokio::main] 该宏仅仅是提供语法糖,目的是让大家可以更简单、更一致的去写异步代码,它会将你写下的async fn main 函数替换为:
fn main() { tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .unwrap() .block_on(async { println!("Hello world"); }) }
注意到上面的 block_on 方法了嘛?在我们自己的同步代码中,可以使用它开启一个 async/await 世界。
mini-redis 的同步接口
在下面,我们将一起构建一个同步的 mini-redis ,为了实现这一点,需要将 Runtime 对象存储起来,然后利用上面提到的 block_on 方法。
首先,创建一个文件 src/blocking_client.rs,然后使用下面代码将异步的 Client 结构体包裹起来:
#![allow(unused)] fn main() { use tokio::net::ToSocketAddrs; use tokio::runtime::Runtime; pub use crate::client::Message; /// 建立到 redis 服务端的连接 pub struct BlockingClient { /// 之前实现的异步客户端 `Client` inner: crate::client::Client, /// 一个 `current_thread` 模式的 `tokio` 运行时, /// 使用阻塞的方式来执行异步客户端 `Client` 上的操作 rt: Runtime, } pub fn connect<T: ToSocketAddrs>(addr: T) -> crate::Result<BlockingClient> { // 构建一个 tokio 运行时: Runtime let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build()?; // 使用运行时来调用异步的连接方法 let inner = rt.block_on(crate::client::connect(addr))?; Ok(BlockingClient { inner, rt }) } }
在这里,我们使用了一个构造器函数用于在同步代码中执行异步的方法:使用 Runtime 上的 block_on 方法来执行一个异步方法并返回结果。
有一个很重要的点,就是我们还使用了 current_thread 模式的运行时。这个可不常见,原因是异步程序往往要利用多线程的威力来实现更高的吞吐性能,相对应的模式就是 multi_thread,该模式会生成多个运行在后台的线程,它们可以高效的实现多个任务的同时并行处理。
但是对于我们的使用场景来说,在同一时间点只需要做一件事,无需并行处理,多个线程并不能帮助到任何事情,因此 current_thread 此时成为了最佳的选择。
在构建 Runtime 的过程中还有一个 enable_all 方法调用,它可以开启 Tokio 运行时提供的 IO 和定时器服务。
由于
current_thread运行时并不生成新的线程,只是运行在已有的主线程上,因此只有当block_on被调用后,该运行时才能执行相应的操作。一旦block_on返回,那运行时上所有生成的任务将再次冻结,直到block_on的再次调用。如果这种模式不符合使用场景的需求,那大家还是需要用
multi_thread运行时来代替。事实上,在 tokio 之前的章节中,我们默认使用的就是multi_thread模式。
#![allow(unused)] fn main() { use bytes::Bytes; use std::time::Duration; impl BlockingClient { pub fn get(&mut self, key: &str) -> crate::Result<Option<Bytes>> { self.rt.block_on(self.inner.get(key)) } pub fn set(&mut self, key: &str, value: Bytes) -> crate::Result<()> { self.rt.block_on(self.inner.set(key, value)) } pub fn set_expires( &mut self, key: &str, value: Bytes, expiration: Duration, ) -> crate::Result<()> { self.rt.block_on(self.inner.set_expires(key, value, expiration)) } pub fn publish(&mut self, channel: &str, message: Bytes) -> crate::Result<u64> { self.rt.block_on(self.inner.publish(channel, message)) } } }
这代码看上去挺长,实际上很简单,通过 block_on 将异步形式的 Client 的方法变成同步调用的形式。例如 BlockingClient 的 get 方法实际上是对内部的异步 get 方法的同步调用。
与上面的平平无奇相比,下面的代码将更有趣,因为它将 Client 转变成一个 Subscriber 对象:
#![allow(unused)] fn main() { /// 下面的客户端可以进入 pub/sub (发布/订阅) 模式 /// /// 一旦客户端订阅了某个消息通道,那就只能执行 pub/sub 相关的命令。 /// 将`BlockingClient` 类型转换成 `BlockingSubscriber` 是为了防止非 `pub/sub` 方法被调用 pub struct BlockingSubscriber { /// 异步版本的 `Subscriber` inner: crate::client::Subscriber, /// 一个 `current_thread` 模式的 `tokio` 运行时, /// 使用阻塞的方式来执行异步客户端 `Client` 上的操作 rt: Runtime, } impl BlockingClient { pub fn subscribe(self, channels: Vec<String>) -> crate::Result<BlockingSubscriber> { let subscriber = self.rt.block_on(self.inner.subscribe(channels))?; Ok(BlockingSubscriber { inner: subscriber, rt: self.rt, }) } } impl BlockingSubscriber { pub fn get_subscribed(&self) -> &[String] { self.inner.get_subscribed() } pub fn next_message(&mut self) -> crate::Result<Option<Message>> { self.rt.block_on(self.inner.next_message()) } pub fn subscribe(&mut self, channels: &[String]) -> crate::Result<()> { self.rt.block_on(self.inner.subscribe(channels)) } pub fn unsubscribe(&mut self, channels: &[String]) -> crate::Result<()> { self.rt.block_on(self.inner.unsubscribe(channels)) } } }
由上可知,subscribe 方法会使用运行时将一个异步的 Client 转变成一个异步的 Subscriber,此外,Subscriber 结构体有一个非异步的方法 get_subscribed,对于这种方法,只需直接调用即可,而无需使用运行时。
其它方法
上面介绍的是最简单的方法,但是,如果只有这一种, tokio 也不会如此大名鼎鼎。
runtime.spawn
可以通过 Runtime 的 spawn 方法来创建一个基于该运行时的后台任务:
use tokio::runtime::Builder; use tokio::time::{sleep, Duration}; fn main() { let runtime = Builder::new_multi_thread() .worker_threads(1) .enable_all() .build() .unwrap(); let mut handles = Vec::with_capacity(10); for i in 0..10 { handles.push(runtime.spawn(my_bg_task(i))); } // 在后台任务运行的同时做一些耗费时间的事情 std::thread::sleep(Duration::from_millis(750)); println!("Finished time-consuming task."); // 等待这些后台任务的完成 for handle in handles { // `spawn` 方法返回一个 `JoinHandle`,它是一个 `Future`,因此可以通过 `block_on` 来等待它完成 runtime.block_on(handle).unwrap(); } } async fn my_bg_task(i: u64) { let millis = 1000 - 50 * i; println!("Task {} sleeping for {} ms.", i, millis); sleep(Duration::from_millis(millis)).await; println!("Task {} stopping.", i); }
运行该程序,输出如下:
Task 0 sleeping for 1000 ms.
Task 1 sleeping for 950 ms.
Task 2 sleeping for 900 ms.
Task 3 sleeping for 850 ms.
Task 4 sleeping for 800 ms.
Task 5 sleeping for 750 ms.
Task 6 sleeping for 700 ms.
Task 7 sleeping for 650 ms.
Task 8 sleeping for 600 ms.
Task 9 sleeping for 550 ms.
Task 9 stopping.
Task 8 stopping.
Task 7 stopping.
Task 6 stopping.
Finished time-consuming task.
Task 5 stopping.
Task 4 stopping.
Task 3 stopping.
Task 2 stopping.
Task 1 stopping.
Task 0 stopping.
在此例中,我们生成了 10 个后台任务在运行时中运行,然后等待它们的完成。作为一个例子,想象一下在图形渲染应用( GUI )中,有时候需要通过网络访问远程服务来获取一些数据,那上面的这种模式就非常适合,因为这些网络访问比较耗时,而且不会影响图形的主体渲染,因此可以在主线程中渲染图形,然后使用其它线程来运行 Tokio 的运行时,并通过该运行时使用异步的方式完成网络访问,最后将这些网络访问的结果发送到 GUI 进行数据渲染,例如一个进度条。
还有一点很重要,在本例子中只能使用 multi_thread 运行时。如果我们使用了 current_thread,你会发现主线程的耗时任务会在后台任务开始之前就完成了。因为在 current_thread 模式下,生成的任务只会在 block_on 期间才执行。
在 multi_thread 模式下,我们并不需要通过 block_on 来触发任务的运行,这里仅仅是用来阻塞并等待最终的结果。而除了通过 block_on 等待结果外,你还可以:
- 使用消息传递的方式,例如
tokio::sync::mpsc,让异步任务将结果发送到主线程,然后主线程通过.recv方法等待这些结果 - 通过共享变量的方式,例如
Mutex,这种方式非常适合实现 GUI 的进度条: GUI 在每个渲染帧读取该变量即可。
发送消息
在同步代码中使用异步的另一个方法就是生成一个运行时,然后使用消息传递的方式跟它进行交互。这个方法虽然更啰嗦一些,但是相对于之前的两种方法更加灵活:
#![allow(unused)] fn main() { use tokio::runtime::Builder; use tokio::sync::mpsc; pub struct Task { name: String, // 一些信息用于描述该任务 } async fn handle_task(task: Task) { println!("Got task {}", task.name); } #[derive(Clone)] pub struct TaskSpawner { spawn: mpsc::Sender<Task>, } impl TaskSpawner { pub fn new() -> TaskSpawner { // 创建一个消息通道用于通信 let (send, mut recv) = mpsc::channel(16); let rt = Builder::new_current_thread() .enable_all() .build() .unwrap(); std::thread::spawn(move || { rt.block_on(async move { while let Some(task) = recv.recv().await { tokio::spawn(handle_task(task)); } // 一旦所有的发送端超出作用域被 drop 后,`.recv()` 方法会返回 None,同时 while 循环会退出,然后线程结束 }); }); TaskSpawner { spawn: send, } } pub fn spawn_task(&self, task: Task) { match self.spawn.blocking_send(task) { Ok(()) => {}, Err(_) => panic!("The shared runtime has shut down."), } } } }
为何说这种方法比较灵活呢?以上面代码为例,它可以在很多方面进行配置。例如,可以使用信号量 Semaphore来限制当前正在进行的任务数,或者你还可以使用一个消息通道将消息反向发送回任务生成器 spawner。
抛开细节,抽象来看,这是不是很像一个 Actor ?
Cargo 使用指南
Rust 语言的名气之所以这么大,保守估计 Cargo 的贡献就占了三分之一。
Cargo 是包管理工具,可以用于依赖包的下载、编译、更新、分发等,与 Cargo 一样有名的还有 crates.io,它是社区提供的包注册中心:用户可以将自己的包发布到该注册中心,然后其它用户通过注册中心引入该包。
本章内容是基于 Cargo Book 翻译,并做了一些内容优化和目录组织上的调整
上手使用
Cargo 会在安装 Rust 的时候一并进行安装,无需我们手动的操作执行,安装 Rust 参见这里。
在开始之前,先来明确一个名词: Package,由于 Crate 被翻译成包,因此 Package 再被翻译成包就很不合适,经过斟酌,我们决定翻译成项目,你也可以理解为工程、软件包,总之,在本书中Package 意味着项目,而项目也意味着 Package 。
安装完成后,接下来使用 Cargo 来创建一个新的二进制项目,二进制意味着该项目可以作为一个服务运行或被编译成可执行文件运行。
#![allow(unused)] fn main() { $ cargo new hello_world }
这里我们使用 cargo new 创建一个新的项目 ,事实上该命令等价于 cargo new hello_world --bin,bin 是 binary 的简写,代表着二进制程序,由于 --bin 是默认参数,因此可以对其进行省略。
创建成功后,先来看看项目的基本目录结构长啥样:
$ cd hello_world
$ tree .
.
├── Cargo.toml
└── src
└── main.rs
1 directory, 2 files
这里有一个很显眼的文件 Cargo.toml,一看就知道它是 Cargo 使用的配置文件,这个关系类似于: package.json 是 npm 的配置文件。
[package]
name = "hello_world"
version = "0.1.0"
edition = "2021"
[dependencies]
以上就是 Cargo.toml 的全部内容,它被称之为清单( manifest ),包含了 Cargo 编译程序所需的所有元数据。
下面是 src/main.rs 的内容 :
fn main() { println!("Hello, world!"); }
可以看出 Cargo 还为我们自动生成了一个 hello world 程序,或者说二进制包,对程序进行编译构建:
$ cargo build
Compiling hello_world v0.1.0 (file:///path/to/package/hello_world)
然后再运行编译出的二进制可执行文件:
$ ./target/debug/hello_world
Hello, world!
注意到路径中的 debug 了吗?它说明我们刚才的编译是 Debug 模式,该模式主要用于测试目的,如果想要进行生产编译,我们需要使用 Release 模式 cargo build --release,然后通过 ./target/release/hello_world 运行。
除了上面的编译 + 运行方式外,在日常开发中,我们还可以使用一个简单的命令直接运行:
$ cargo run
Fresh hello_world v0.1.0 (file:///path/to/package/hello_world)
Running `target/hello_world`
Hello, world!
cargo run 会帮我们自动完成编译、运行的过程,当然,该命令也支持 Release 模式: cargo run --release。
如果你的程序在跑性能测试 benchmark,一定要使用
Release模式,因为该模式下,程序会做大量性能优化
在快速了解 Cargo 的使用方式后,下面,我们将正式进入 Cargo 的学习之旅。
使用手册
在本章中,我们将学习 Cargo 的详细使用方式,例如 Package 的创建与管理、依赖拉取、Package 结构描述等。
为何会有 Cargo
根据之前学习的知识,Rust 有两种类型的包: 库包和二进制包,前者是我们经常使用的依赖包,用于被其它包所引入,而后者是一个应用服务,可以编译成二进制可执行文件进行运行。
包是通过 Rust 编译器 rustc 进行编译的:
#![allow(unused)] fn main() { $ rustc hello.rs $ ./hello Hello, world! }
上面我们直接使用 rustc 对二进制包 hello.rs 进行编译,生成二进制可执行文件 hello,并对其进行运行。
该方式虽然简单,但有几个问题:
- 必须要指定文件名编译,当项目复杂后,这种编译方式也随之更加复杂
- 如果要指定编译参数,情况将更加复杂
最关键的是,外部依赖库的引入也将是一个大问题。大部分实际的项目都有不少依赖包,而这些依赖包又间接的依赖了新的依赖包,在这种复杂情况下,如何管理依赖包及其版本也成为一个相当棘手的问题。
正是因为这些原因,与其使用 rustc ,我们可以使用一个强大的包管理工具来解决问题:欢迎 Cargo 闪亮登场。
Cargo
Cargo 解决了之前描述的所有问题,同时它保证了每次重复的构建都不会改变上一次构建的结果,这背后是通过完善且强大的依赖包版本管理来实现的。
总之,Cargo 为了实现目标,做了四件事:
- 引入两个元数据文件,包含项目的方方面面信息:
Cargo.toml和Cargo.lock - 获取和构建项目的依赖,例如
Cargo.toml中的依赖包版本描述,以及从crates.io下载包 - 调用
rustc(或其它编译器) 并使用的正确的参数来构建项目,例如cargo build - 引入一些惯例,让项目的使用更加简单
毫不夸张的说,得益于 Cargo 的标准化,只要你使用它构建过一个项目,那构建其它使用 Cargo 的项目,也将不存在任何困难。
下载并构建 Package
如果看中 GitHub 上的某个开源 Rust 项目,那下载并构建它将是非常简单的。
$ git clone https://github.com/rust-lang/regex.git
$ cd regex
如上所示,直接从 GitHub 上克隆下来想要的项目,然后使用 cargo build 进行构建即可:
$ cargo build
Compiling regex v1.5.0 (file:///path/to/package/regex)
该命令将下载相关的依赖库,等下载成功后,再对 package 和下载的依赖进行一同的编译构建。
这就是包管理工具的强大之处,cargo build 搞定一切,而背后隐藏的复杂配置、参数你都无需关心。
添加依赖
crates.io 是 Rust 社区维护的中心化注册服务,用户可以在其中寻找和下载所需的包。对于 cargo 来说,默认就是从这里下载依赖。
下面我们来添加一个 time 依赖包,若你的 Cargo.toml 文件中没有 [dependencies] 部分,就手动添加一个,并添加目标包名和版本号:
[dependencies]
time = "0.1.12"
可以看到我们指定了 time 包的版本号 "0.1.12",关于版本号,实际上还有其它的指定方式,具体参见指定依赖项章节。
如果想继续添加 regexp 包,只需在 time 包后面添加即可 :
[package]
name = "hello_world"
version = "0.1.0"
edition = "2021"
[dependencies]
time = "0.1.12"
regex = "0.1.41"
此时,再通过运行 cargo build 来重新构建,首先 Cargo 会获取新的依赖以及依赖的依赖, 接着对它们进行编译并更新 Cargo.lock:
$ cargo build
Updating crates.io index
Downloading memchr v0.1.5
Downloading libc v0.1.10
Downloading regex-syntax v0.2.1
Downloading memchr v0.1.5
Downloading aho-corasick v0.3.0
Downloading regex v0.1.41
Compiling memchr v0.1.5
Compiling libc v0.1.10
Compiling regex-syntax v0.2.1
Compiling memchr v0.1.5
Compiling aho-corasick v0.3.0
Compiling regex v0.1.41
Compiling hello_world v0.1.0 (file:///path/to/package/hello_world)
在 Cargo.lock 中包含了我们项目使用的所有依赖的准确版本信息。这个非常重要,未来就算 regexp 的作者升级了该包,我们依然会下载 Cargo.lock 中的版本,而不是最新的版本,只有这样,才能保证项目依赖包不会莫名其妙的因为更新升级导致无法编译。 当然,你还可以使用 cargo update 来手动更新包的版本。
此时,就可以在 src/main.rs 中使用新引入的 regexp 包:
use regex::Regex; fn main() { let re = Regex::new(r"^\d{4}-\d{2}-\d{2}$").unwrap(); println!("Did our date match? {}", re.is_match("2014-01-01")); }
运行后输出:
$ cargo run
Running `target/hello_world`
Did our date match? true
标准的 Package 目录结构
一个典型的 Package 目录结构如下:
.
├── Cargo.lock
├── Cargo.toml
├── src/
│ ├── lib.rs
│ ├── main.rs
│ └── bin/
│ ├── named-executable.rs
│ ├── another-executable.rs
│ └── multi-file-executable/
│ ├── main.rs
│ └── some_module.rs
├── benches/
│ ├── large-input.rs
│ └── multi-file-bench/
│ ├── main.rs
│ └── bench_module.rs
├── examples/
│ ├── simple.rs
│ └── multi-file-example/
│ ├── main.rs
│ └── ex_module.rs
└── tests/
├── some-integration-tests.rs
└── multi-file-test/
├── main.rs
└── test_module.rs
这也是 Cargo 推荐的目录结构,解释如下:
Cargo.toml和Cargo.lock保存在package根目录下- 源代码放在
src目录下 - 默认的
lib包根是src/lib.rs - 默认的二进制包根是
src/main.rs- 其它二进制包根放在
src/bin/目录下
- 其它二进制包根放在
- 基准测试 benchmark 放在
benches目录下 - 示例代码放在
examples目录下 - 集成测试代码放在
tests目录下
关于 Rust 中的包和模块,之前的章节有更详细的解释。
此外,bin、tests、examples 等目录路径都可以通过配置文件进行配置,它们被统一称之为 Cargo Target。
Cargo.toml vs Cargo.lock
Cargo.toml 和 Cargo.lock 是 Cargo 的两个元配置文件,但是它们拥有不同的目的:
- 前者从用户的角度出发来描述项目信息和依赖管理,因此它是由用户来编写
- 后者包含了依赖的精确描述信息,它是由
Cargo自行维护,因此不要去手动修改
它们的关系跟 package.json 和 package-lock.json 非常相似,从 JavaScript 过来的同学应该会比较好理解。
是否上传本地的 Cargo.lock
当本地开发时,Cargo.lock 自然是非常重要的,但是当你要把项目上传到 Git 时,例如 GitHub,那是否上传 Cargo.lock 就成了一个问题。
关于是否上传,有如下经验准则:
- 从实践角度出发,如果你构建的是三方库类型的服务,请把
Cargo.lock加入到.gitignore中。 - 若构建的是一个面向用户终端的产品,例如可以像命令行工具、应用程序一样执行,那就把
Cargo.lock上传到源代码目录中。
例如 axum 是 web 开发框架,它属于三方库类型的服务,因此源码目录中不应该出现 Cargo.lock 的身影,它的归宿是 .gitignore。而 ripgrep 则恰恰相反,因为它是一个面向终端的产品,可以直接运行提供服务。
那么问题来了,为何会有这种选择?
原因是 Cargo.lock 会详尽描述上一次成功构建的各种信息:环境状态、依赖、版本等等,Cargo 可以使用它提供确定性的构建环境和流程,无论何时何地。这种特性对于终端服务是非常重要的:能确定、稳定的在用户环境中运行起来是终端服务最重要的特性之一。
而对于三方库来说,情况就有些不同。它不仅仅被库的开发者所使用,还会间接影响依赖链下游的使用者。用户引入了三方库是不会去看它的 Cargo.lock 信息的,也不应该受这个库的确定性运行条件所限制。
还有个原因,在项目中,可能会有几个依赖库引用同一个三方库的同一个版本,那如果该三方库使用了 Cargo.lock 文件,那可能三方库的多个版本会被引入使用,这时就会造成版本冲突。换句话说,通过指定版本的方式引用一个依赖库是无法看到该依赖库的完整情况的,而只有终端的产品才会看到这些完整的情况。
假设没有 Cargo.lock
Cargo.toml 是一个清单文件( manifest )包含了我们 package 的描述元数据。例如,通过以下内容可以说明对另一个 package 的依赖 :
#![allow(unused)] fn main() { [package] name = "hello_world" version = "0.1.0" [dependencies] regex = { git = "https://github.com/rust-lang/regex.git" } }
可以看到,只有一个依赖,且该依赖的来源是 GitHub 上一个特定的仓库。由于我们没有指定任何版本信息,Cargo 会自动拉取该依赖库的最新版本( master 或 main 分支上的最新 commit )。
这种使用方式,其实就错失了包管理工具的最大的优点:版本管理。例如你在今天构建使用了版本 A,然后过了一段时间后,由于依赖包的升级,新的构建却使用了大更新版本 B,结果因为版本不兼容,导致了构建失败。
可以看出,确保依赖版本的确定性是非常重要的:
#![allow(unused)] fn main() { [dependencies] regex = { git = "https://github.com/rust-lang/regex.git", rev = "9f9f693" } }
这次,我们使用了指定 rev ( revision ) 的方式来构建,那么不管未来何时再次构建,使用的依赖库都会是该 rev ,而不是最新的 commit。
但是,这里还有一个问题:rev 需要手动的管理,你需要在每次更新包的时候都思考下 SHA-1,这显然非常麻烦。
当有了 Cargo.lock 后
当有了 Cargo.lock 后,我们无需手动追踪依赖库的 rev,Cargo 会自动帮我们完成,还是之前的清单:
#![allow(unused)] fn main() { [package] name = "hello_world" version = "0.1.0" [dependencies] regex = { git = "https://github.com/rust-lang/regex.git" } }
第一次构建时,Cargo 依然会拉取最新的 master commit,然后将以下信息写到 Cargo.lock 文件中:
#![allow(unused)] fn main() { [[package]] name = "hello_world" version = "0.1.0" dependencies = [ "regex 1.5.0 (git+https://github.com/rust-lang/regex.git#9f9f693768c584971a4d53bc3c586c33ed3a6831)", ] [[package]] name = "regex" version = "1.5.0" source = "git+https://github.com/rust-lang/regex.git#9f9f693768c584971a4d53bc3c586c33ed3a6831" }
可以看出,其中包含了依赖库的准确 rev 信息。当未来再次构建时,只要项目中还有该 Cargo.lock 文件,那构建依然会拉取同一个版本的依赖库,并且再也无需我们手动去管理 rev 的 SHA 信息!
更新依赖
由于 Cargo.lock 会锁住依赖的版本,你需要通过手动的方式将依赖更新到新的版本:
#![allow(unused)] fn main() { $ cargo update # 更新所有依赖 $ cargo update -p regex # 只更新 “regex” }
以上命令将使用新的版本信息重新生成 Cargo.lock ,需要注意的是 cargo update -p regex 传递的参数实际上是一个 Package ID, regex 只是一个简写形式。
测试和 CI
Cargo 可以通过 cargo test 命令运行项目中的测试文件:它会在 src/ 底下的文件寻找单元测试,也会在 tests/ 目录下寻找集成测试。
#![allow(unused)] fn main() { $ cargo test Compiling regex v1.5.0 (https://github.com/rust-lang/regex.git#9f9f693) Compiling hello_world v0.1.0 (file:///path/to/package/hello_world) Running target/test/hello_world-9c2b65bbb79eabce running 0 tests test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out }
从上面结果可以看出,项目中实际上还没有任何测试代码。
事实上,除了单元测试、集成测试,cargo test 还会编译 examples/ 下的示例文件以及文档中的示例。
如果希望深入学习如何在 Rust 编写及运行测试,请查阅该章节。
CI
持续集成是软件开发中异常重要的一环,大家应该都听说过 Jenkins,它就是一个拥有悠久历史的持续集成工具。简单来说,持续集成会定期拉取同一个项目中所有成员的相关代码,对其进行自动化构建。
在没有持续集成前,首先开发者需要手动编译代码并运行单元测试、集成测试等基础测试,然后启动项目相关的所有服务,接着测试人员开始介入对整个项目进行回归测试、黑盒测试等系统化的测试,当测试通过后,最后再手动发布到指定的环境中运行,这个过程是非常冗长,且所有成员都需要同时参与的。
在有了持续集成后,只要编写好相应的编译、测试、发布配置文件,那持续集成平台会自动帮助我们完成整个相关的流程,期间无需任何人介入,高效且可靠。
GitHub Actions
关于如何使用 GitHub Actions 进行持续集成,在之前的章节已经有过详细的介绍,这里就不再赘述。
Travis CI
以下是 Travis CI 需要的一个简单的示例配置文件:
language: rust
rust:
- stable
- beta
- nightly
matrix:
allow_failures:
- rust: nightly
以上配置将测试所有的 Rust 发布版本,但是 nightly 版本的构建失败不会导致全局测试的失败,可以查看 Travis CI Rust 文档 获取更详细的说明。
Gitlab CI
以下是一个示例 .gitlab-ci.yml 文件:
stages:
- build
rust-latest:
stage: build
image: rust:latest
script:
- cargo build --verbose
- cargo test --verbose
rust-nightly:
stage: build
image: rustlang/rust:nightly
script:
- cargo build --verbose
- cargo test --verbose
allow_failure: true
这里将测试 stable 和 nightly 发布版本,同样的,nightly 下的测试失败不会导致全局测试的失败。查看 Gitlab CI 文档 获取更详细的说明。
Cargo 缓存
Cargo 使用了缓存的方式提升构建效率,当构建时,Cargo 会将已下载的依赖包放在 CARGO_HOME 目录下,下面一起来看看。
Cargo Home
默认情况下,Cargo Home 所在的目录是 $HOME/.cargo/,例如在 macos ,对应的目录是:
$ echo $HOME/.cargo/
/Users/sunfei/.cargo/
我们也可以通过修改 CARGO_HOME 环境变量的方式来重新设定该目录的位置。若你需要在项目中通过代码的方式来获取 CARGO_HOME ,home 包提供了相应的 API。
注意! Cargo Home 目录的内部结构并没有稳定化,在未来可能会发生变化
文件
config.toml是 Cargo 的全局配置文件,具体请查看这里credentials.toml为cargo login提供私有化登录证书,用于登录package注册中心,例如crates.io.crates.toml,.crates2.json这两个是隐藏文件,包含了通过cargo install安装的包的package信息,请不要手动修改!
目录
bin目录包含了通过cargo install或rustup下载的包编译出的可执行文件。你可以将该目录加入到$PATH环境变量中,以实现对这些可执行文件的直接访问git中存储了Git的资源文件:git/db,当一个包依赖某个git仓库时,Cargo会将该仓库克隆到git/db目录下,如果未来需要还会对其进行更新git/checkouts,若指定了git源和commit,那相应的仓库就会从git/db中checkout到该目录下,因此同一个仓库的不同checkout共存成为了可能性
registry包含了注册中心( 例如crates.io)的元数据 和packagesregistry/index是一个 git 仓库,包含了注册中心中所有可用包的元数据( 版本、依赖等 )registry/cache中保存了已下载的依赖,这些依赖包以gzip的压缩档案形式保存,后缀名为.crateregistry/src,若一个已下载的.crate档案被一个package所需要,该档案会被解压缩到registry/src文件夹下,最终rustc可以在其中找到所需的.rs文件
在 CI 时缓存 Cargo Home
为了避免持续集成时重复下载所有的包依赖,我们可以将 $CARGO_HOME 目录进行缓存,但缓存整个目录是效率低下的,原因是源文件可能会被缓存两次。
例如我们依赖一个包 serde 1.0.92,如果将整个 $CACHE_HOME 目录缓存,那么serde 的源文件就会被缓存两次:在 registry/cache 中的 serde-1.0.92.crate 以及 registry/src 下被解压缩的 .rs 文件。
因此,在 CI 构建时,出于效率的考虑,我们仅应该缓存以下目录:
bin/registry/index/registry/cache/git/db/
清除缓存
理论上,我们可以手动移除缓存中的任何一部分,当后续有包需要时 Cargo 会尽可能去恢复这些资源:
- 解压缩
registry/cache下的.crate档案 - 从
.git中checkout缓存的仓库 - 如果以上都没了,会从网络上重新下载
你也可以使用 cargo-cache 包来选择性的清除 cache 中指定的部分,当然,它还可以用来查看缓存中的组件大小。
构建时卡住:Blocking waiting for file lock ..
在开发过程中,或多或少我们都会碰到这种问题,例如你同时打开了 VSCode IDE 和终端,然后在 Cargo.toml 中刚添加了一个新的依赖。
此时 IDE 会捕捉到这个修改然后自动去重新下载依赖(这个过程可能还会更新 crates.io 使用的索引列表),在此过程中, Cargo 会将相关信息写入到 $HOME/.cargo/.package_cache 下,并将其锁住。
如果你试图在另一个地方(例如终端)对同一个项目进行构建,就会报错: Blocking waiting for file lock on package cache。
解决办法很简单:
- 既然下载慢,那就使用国内的注册服务,不再使用 crates.io
- 耐心等待持有锁的用户构建完成
- 强行停止正在构建的进程,例如杀掉 IDE 使用的 rust-analyer 插件进程,然后删除
$HOME/.cargo/.package_cache目录
构建( Build )缓存
cargo build 的结果会被放入项目根目录下的 target 文件夹中,当然,这个位置可以三种方式更改:设置 CARGO_TARGET_DIR 环境变量、build.target-dir 配置项以及 --target-dir 命令行参数。
target 目录结构
target 目录的结构取决于是否使用 --target 标志为特定的平台构建。
不使用 --target
若 --target 标志没有指定,Cargo 会根据宿主机架构进行构建,构建结果会放入项目根目录下的 target 目录中,target 下每个子目录中包含了相应的 发布配置profile 的构建结果,例如 release、debug 是自带的profile,前者往往用于生产环境,因为会做大量的性能优化,而后者则用于开发环境,此时的编译效率和报错信息是最好的。
除此之外我们还可以定义自己想要的 profile ,例如用于测试环境的 profile: test,用于预发环境的 profile :pre-prod 等。
| 目录 | 描述 |
|---|---|
target/debug/ | 包含了 dev profile 的构建输出(cargo build 或 cargo build --debug) |
target/release/ | release profile 的构建输出,cargo build --release |
target/foo/ | 自定义 foo profile 的构建输出,cargo build --profile=foo |
出于历史原因:
dev和testprofile 的构建结果都存放在debug目录下release和benchprofile 则存放在release目录下- 用户定义的 profile 存在同名的目录下
使用 --target
当使用 --target XXX 为特定的平台编译后,输出会放在 target/XXX/ 目录下:
| 目录 | 示例 |
|---|---|
target/<triple>/debug/ | target/thumbv7em-none-eabihf/debug/ |
target/<triple>/release/ | target/thumbv7em-none-eabihf/release/ |
注意:,当没有使用
--target时,Cargo会与构建脚本和过程宏一起共享你的依赖包,对于每个rustc命令调用而言,RUSTFLAGS也将被共享。而使用
--target后,构建脚本、过程宏会针对宿主机的 CPU 架构进行各自构建,且不会共享RUSTFLAGS。
target 子目录说明
在 profile 文件夹中(例如 debug 或 release),包含编译后的最终成果:
| 目录 | 描述 |
|---|---|
target/debug/ | 包含编译后的输出,例如二进制可执行文件、库对象( library target ) |
target/debug/examples/ | 包含示例对象( example target ) |
还有一些命令会在 target 下生成自己的独立目录:
| 目录 | 描述 |
|---|---|
target/doc/ | 包含通过 cargo doc 生成的文档 |
target/package/ | 包含 cargo package 或 cargo publish 生成的输出 |
Cargo 还会创建几个用于构建过程的其它类型目录,它们的目录结构只应该被 Cargo 自身使用,因此可能会在未来发生变化:
| 目录 | 描述 |
|---|---|
target/debug/deps | 依赖和其它输出成果 |
target/debug/incremental | rustc 增量编译的输出,该缓存可以用于提升后续的编译速度 |
target/debug/build/ | 构建脚本的输出 |
依赖信息文件
在每一个编译成果的旁边,都有一个依赖信息文件,文件后缀是 .d。该文件的语法类似于 Makefile,用于说明构建编译成果所需的所有依赖包。
该文件往往用于提供给外部的构建系统,这样它们就可以判断 Cargo 命令是否需要再次被执行。
文件中的路径默认是绝对路径,你可以通过 build.dep-info-basedir 配置项来修改为相对路径。
# 关于 `.d` 文件的一个示例 : target/debug/foo.d
/path/to/myproj/target/debug/foo: /path/to/myproj/src/lib.rs /path/to/myproj/src/main.rs
共享缓存
sccache 是一个三方工具,可以用于在不同的工作空间中共享已经构建好的依赖包。
为了设置 sccache,首先需要使用 cargo install sccache 进行安装,然后在调用 Cargo 之前将 RUSTC_WRAPPER 环境变量设置为 sccache。
- 如果用的
bash,可以将export RUSTC_WRAPPER=sccache添加到.bashrc中 - 也可以使用
build.rustc-wrapper配置项
进阶指南
进阶指南包含了 Cargo 的参考级内容,大家可以先看一遍了解下大概有什么,然后在后面需要时,再回来查询如何使用。
指定依赖项
我们的项目可以引用在 crates.io 或 GitHub 上的依赖包,也可以引用存放在本地文件系统中的依赖包。
大家可能会想,直接从前两个引用即可,为何还提供了本地方式?可以设想下,如果你要有一个正处于开发中的包,然后需要在本地的另一个项目中引用测试,那是将该包先传到网上,然后再引用简单,还是直接从本地路径的方式引用简单呢?答案显然不言而喻。
本章节,我们一起来看看有哪些方式可以指定和引用三方依赖包。
从 crates.io 引入依赖包
默认设置下,Cargo 就从 crates.io 上下载依赖包,只需要一个包名和版本号即可:
[dependencies]
time = "0.1.12"
字符串 "0.1.12" 是一个 semver 格式的版本号,符合 "x.y.z" 的形式,其中 x 被称为主版本major, y 被称为小版本 minor ,而 z 被称为补丁 patch,可以看出从左到右,版本的影响范围逐步降低,补丁的更新是无关痛痒的,并不会造成 API 的兼容性被破坏。
"0.1.12" 中并没有任何额外的符号,在版本语义上,它跟使用了 ^ 的 "^0.1.12" 是相同的,都是指定非常具体的版本进行引入。
但是 ^ 能做的更多。
npm 使用的就是
semver版本号,从JavaScript过来的同学应该非常熟悉。
^ 指定版本
与之前的 "0.1.12" 不同, ^ 可以指定一个版本号范围,然后会使用该范围内的最大版本号来引用对应的包。
只要新的版本号没有修改最左边的非零数字,那该版本号就在允许的版本号范围中。例如 "^0.1.12" 最左边的非零数字是 1,因此,只要新的版本号是 "0.1.z" 就可以落在范围内,而0.2.0 显然就没有落在范围内,因此通过 "^0.1.12" 引入的依赖包是无法被升级到 0.2.0 版本的。
同理,若是 "^1.0",则 1.1 在范围中,2.0 则不在。 大家思考下,"^0.0.1" 与哪些版本兼容?答案是:无,因为它最左边的数字是 1 ,而该数字已经退无可退,我们又不能修改 1,因此没有版本落在范围中。
^1.2.3 := >=1.2.3, <2.0.0
^1.2 := >=1.2.0, <2.0.0
^1 := >=1.0.0, <2.0.0
^0.2.3 := >=0.2.3, <0.3.0
^0.2 := >=0.2.0, <0.3.0
^0.0.3 := >=0.0.3, <0.0.4
^0.0 := >=0.0.0, <0.1.0
^0 := >=0.0.0, <1.0.0
以上是更多的例子,事实上,这个规则跟 SemVer 还有所不同,因为对于 SemVer 而言,0.x.y 的版本是没有其它版本与其兼容的,而对于 Rust,只要版本号 0.x.y 满足 : z>=y 且 x>0 的条件,那它就能更新到 0.x.z 版本。
~ 指定版本
~ 指定了最小化版本 :
#![allow(unused)] fn main() { ~1.2.3 := >=1.2.3, <1.3.0 ~1.2 := >=1.2.0, <1.3.0 ~1 := >=1.0.0, <2.0.0 }
* 通配符
这种方式允许将 * 所在的位置替换成任何数字:
#![allow(unused)] fn main() { * := >=0.0.0 1.* := >=1.0.0, <2.0.0 1.2.* := >=1.2.0, <1.3.0 }
不过 crates.io 并不允许我们只使用孤零零一个 * 来指定版本号 : *。
比较符
可以使用比较符的方式来指定一个版本号范围或一个精确的版本号:
#![allow(unused)] fn main() { >= 1.2.0 > 1 < 2 = 1.2.3 }
同时还能使用比较符进行组合,并通过逗号分隔:
#![allow(unused)] fn main() { >= 1.2, < 1.5 }
需要注意,以上的版本号规则仅仅针对 crate.io 和基于它搭建的注册服务(例如科大服务源) ,其它注册服务(例如 GitHub )有自己相应的规则。
从其它注册服务引入依赖包
为了使用 crates.io 之外的注册服务,我们需要对 $HOME/.cargo/config.toml (CARGO_HOME 下) 文件进行配置,添加新的服务提供商,有两种方式可以实现。
由于国内访问国外注册服务的不稳定性,我们可以使用科大的注册服务来提升下载速度,以下注册服务的链接都是科大的
首先是在 crates.io 之外添加新的注册服务,修改 .cargo/config.toml 添加以下内容:
[registries]
ustc = { index = "https://mirrors.ustc.edu.cn/crates.io-index/" }
对于这种方式,我们的项目的 Cargo.toml 中的依赖包引入方式也有所不同:
[dependencies]
time = { registry = "ustc" }
在重新配置后,初次构建可能要较久的时间,因为要下载更新 ustc 注册服务的索引文件,还挺大的...
注意,这一种使用方式最大的缺点就是在引用依赖包时要指定注册服务: time = { registry = "ustc" }。
而第二种方式就不需要,因为它是直接使用新注册服务来替代默认的 crates.io。
[source.crates-io]
replace-with = 'ustc'
[source.ustc]
registry = "git://mirrors.ustc.edu.cn/crates.io-index"
上面配置中的第一个部分,首先将源 source.crates-io 替换为 ustc,然后在第二部分指定了 ustc 源的地址。
注意,如果你要发布包到
crates.io上,那该包的依赖也必须在crates.io上
引入 git 仓库作为依赖包
若要引入 git 仓库中的库作为依赖包,你至少需要提供一个仓库的地址:
[dependencies]
regex = { git = "https://github.com/rust-lang/regex" }
由于没有指定版本,Cargo 会假定我们使用 master 或 main 分支的最新 commit 。你可以使用 rev、tag 或 branch 来指定想要拉取的版本。例如下面代码拉取了 next 分支上的最新 commit:
[dependencies]
regex = { git = "https://github.com/rust-lang/regex", branch = "next" }
任何非 tag 和 branch 的类型都可以通过 rev 来引入,例如通过最近一次 commit 的哈希值引入: rev = "4c59b707",再比如远程仓库提供的的具名引用: rev = "refs/pull/493/head"。
一旦 git 依赖被拉取下来,该版本就会被记录到 Cargo.lock 中进行锁定。因此 git 仓库中后续新的提交不再会被自动拉取,除非你通过 cargo update 来升级。需要注意的是锁定一旦被删除,那 Cargo 依然会按照 Cargo.toml 中配置的地址和版本去拉取新的版本,如果你配置的版本不正确,那可能会拉取下来一个不兼容的新版本!
因此不要依赖锁定来完成版本的控制,而应该老老实实的在 Cargo.toml 小心配置你希望使用的版本。
如果访问的是私有仓库,你可能需要授权来访问该仓库,可以查看这里了解授权的方式。
通过路径引入本地依赖包
Cargo 支持通过路径的方式来引入本地的依赖包:一般来说,本地依赖包都是同一个项目内的内部包,例如假设我们有一个 hello_world 项目( package ),现在在其根目录下新建一个包:
# 在 hello_world/ 目录下
cargo new hello_utils
新建的 hello_utils 文件夹跟 src、Cargo.toml 同级,现在修改 Cargo.toml 让 hello_world 项目引入新建的包:
[dependencies]
hello_utils = { path = "hello_utils" }
# 以下路径也可以
# hello_utils = { path = "./hello_utils" }
# hello_utils = { path = "../hello_world/hello_utils" }
但是,此时的 hello_world 是无法发布到 crates.io 上的。想要发布,需要先将 hello_utils 先发布到 crates.io 上,然后再通过 crates.io 的方式来引入:
[dependencies]
hello_utils = { path = "hello_utils", version = "0.1.0" }
注意!使用
path指定依赖的 package 将无法发布到crates.io,除非path存在于 [dev-dependencies] 中。当然,你还可以使用多种引用混合的方式来解决这个问题,下面将进行介绍
多引用方式混合
实际上,我们可以同时使用多种方式来引入同一个包,例如本地引入和 crates.io :
[dependencies]
# 本地使用时,通过 path 引入,
# 发布到 `crates.io` 时,通过 `crates.io` 的方式引入: version = "1.0"
bitflags = { path = "my-bitflags", version = "1.0" }
# 本地使用时,通过 git 仓库引入
# 当发布时,通过 `crates.io` 引入: version = "1.0"
smallvec = { git = "https://github.com/servo/rust-smallvec", version = "1.0" }
# N.B. 若 version 无法匹配,Cargo 将无法编译
这种方式跟下章节将要讲述的依赖覆盖类似,但是前者只会应用到当前声明的依赖包上。
根据平台引入依赖
我们还可以根据特定的平台来引入依赖:
[target.'cfg(windows)'.dependencies]
winhttp = "0.4.0"
[target.'cfg(unix)'.dependencies]
openssl = "1.0.1"
[target.'cfg(target_arch = "x86")'.dependencies]
native = { path = "native/i686" }
[target.'cfg(target_arch = "x86_64")'.dependencies]
native = { path = "native/x86_64" }
此处的语法跟 Rust 的 #[cfg] 语法非常相像,因此我们还能使用逻辑操作符进行控制:
[target.'cfg(not(unix))'.dependencies]
openssl = "1.0.1"
这里的意思是,当不是 unix 操作系统时,才对 openssl 进行引入。
如果你想要知道 cfg 能够作用的目标,可以在终端中运行 rustc --print=cfg 进行查询。当然,你可以指定平台查询: rustc --print=cfg --target=x86_64-pc-windows-msvc,该命令将对 64bit 的 Windows 进行查询。
聪明的同学已经发现,这非常类似于条件依赖引入,那我们是不是可以根据自定义的条件来决定是否引入某个依赖呢?具体答案参见后续的 feature 章节。这里是一个简单的示例:
[dependencies]
foo = { version = "1.0", optional = true }
bar = { version = "1.0", optional = true }
[features]
fancy-feature = ["foo", "bar"]
但是需要注意的是,你如果妄图通过 cfg(feature)、cfg(debug_assertions), cfg(test) 和 cfg(proc_macro) 的方式来条件引入依赖,那是不可行的。
Cargo 还允许通过下面的方式来引入平台特定的依赖:
[target.x86_64-pc-windows-gnu.dependencies]
winhttp = "0.4.0"
[target.i686-unknown-linux-gnu.dependencies]
openssl = "1.0.1"
自定义 target 引入
如果你在使用自定义的 target :例如 --target bar.json,那么可以通过下面方式来引入依赖:
[target.bar.dependencies]
winhttp = "0.4.0"
[target.my-special-i686-platform.dependencies]
openssl = "1.0.1"
native = { path = "native/i686" }
需要注意,这种使用方式在
stable版本的 Rust 中无法被使用,建议大家如果没有特别的需求,还是使用之前提到的 feature 方式
[dev-dependencies]
你还可以为项目添加只在测试时需要的依赖库,类似于 package.json( Nodejs )文件中的 devDependencies,可以在 Cargo.toml 中添加 [dev-dependencies] 来实现:
[dev-dependencies]
tempdir = "0.3"
这里的依赖只会在运行测试、示例和 benchmark 时才会被引入。并且,假设A 包引用了 B,而 B 通过 [dev-dependencies] 的方式引用了 C 包, 那 A 是不会引用 C 包的。
当然,我们还可以指定平台特定的测试依赖包:
[target.'cfg(unix)'.dev-dependencies]
mio = "0.0.1"
注意,当发布包到 crates.io 时,
[dev-dependencies]中的依赖只有指定了version的才会被包含在发布包中。况且,再加上测试稳定性的考虑,我们建议为[dev-dependencies]中的包指定相应的版本号
[build-dependencies]
我们还可以指定某些依赖仅用于构建脚本:
[build-dependencies]
cc = "1.0.3"
当然,平台特定的依然可以使用:
[target.'cfg(unix)'.build-dependencies]
cc = "1.0.3"
有一点需要注意:构建脚本( build.rs )和项目的正常代码是彼此独立,因此它们的依赖不能互通: 构建脚本无法使用 [dependencies] 或 [dev-dependencies] 中的依赖,而 [build-dependencies] 中的依赖也无法被构建脚本之外的代码所使用。
选择 features
如果你依赖的包提供了条件性的 features,你可以指定使用哪一个:
[dependencies.awesome]
version = "1.3.5"
default-features = false # 不要包含默认的 features,而是通过下面的方式来指定
features = ["secure-password", "civet"]
更多的信息参见 Features 章节
在 Cargo.toml 中重命名依赖
如果你想要实现以下目标:
- 避免在 Rust 代码中使用
use foo as bar - 依赖某个包的多个版本
- 依赖来自于不同注册服务的同名包
那可以使用 Cargo 提供的 package key :
[package]
name = "mypackage"
version = "0.0.1"
[dependencies]
foo = "0.1"
bar = { git = "https://github.com/example/project", package = "foo" }
baz = { version = "0.1", registry = "custom", package = "foo" }
此时,你的代码中可以使用三个包:
#![allow(unused)] fn main() { extern crate foo; // 来自 crates.io extern crate bar; // 来自 git repository extern crate baz; // 来自 registry `custom` }
有趣的是,由于这三个 package 的名称都是 foo(在各自的 Cargo.toml 中定义),因此我们显式的通过 package = "foo" 的方式告诉 Cargo:我们需要的就是这个 foo package,虽然它被重命名为 bar 或 baz。
有一点需要注意,当使用可选依赖时,如果你将 foo 包重命名为 bar 包,那引用前者的 feature 时的路径名也要做相应的修改:
[dependencies]
bar = { version = "0.1", package = 'foo', optional = true }
[features]
log-debug = ['bar/log-debug'] # 若使用 'foo/log-debug' 会导致报错
依赖覆盖
依赖覆盖对于本地开发来说,是很常见的,大部分原因都是我们希望在某个包发布到 crates.io 之前使用它,例如:
- 你正在同时开发一个包和一个项目,而后者依赖于前者,你希望能在该项目中对正在开发的包进行测试
- 你引入的一个依赖包在
master分支发布了新的代码,恰好修复了某个 bug,因此你希望能单独对该分支进行下测试 - 你即将发布一个包的新版本,为了确保新版本正常工作,你需要对其进行集成测试
- 你为项目的某个依赖包提了一个 PR 并解决了一个重要 bug,在等待合并到
master分支,但是时间不等人,因此你决定先使用自己修改的版本,等未来合并后,再继续使用官方版本
下面我们来具体看看类似的问题该如何解决。
上一章节中我们讲了如果通过多种引用方式来引入一个包,其实这也是一种依赖覆盖。
测试 bugfix 版本
假设我们有一个项目正在使用 uuid 依赖包,但是却不幸地发现了一个 bug,由于这个 bug 影响了使用,没办法等到官方提交新版本,因此还是自己修复为好。
我们项目的 Cargo.toml 内容如下:
[package]
name = "my-library"
version = "0.1.0"
[dependencies]
uuid = "0.8.2"
为了修复 bug,首先需要将 uuid 的源码克隆到本地,笔者是克隆到和项目同级的目录下:
git clone https://github.com/uuid-rs/uuid
下面,修改项目的 Cargo.toml 添加以下内容以引入本地克隆的版本:
[patch.crates-io]
uuid = { path = "../uuid" }
这里我们使用自己修改过的 patch 来覆盖来自 crates.io 的版本,由于克隆下来的 uuid 目录和我们的项目同级,因此通过相对路径 "../uuid" 即可定位到。
在成功为 uuid 打了本地补丁后,现在尝试在项目下运行 cargo build,但是却报错了,而且报错内容有一些看不太懂:
$ cargo build
Updating crates.io index
warning: Patch `uuid v1.0.0-alpha.1 (/Users/sunfei/development/rust/demos/uuid)` was not used in the crate graph.
Check that the patched package version and available features are compatible
with the dependency requirements. If the patch has a different version from
what is locked in the Cargo.lock file, run `cargo update` to use the new
version. This may also occur with an optional dependency that is not enabled.
具体原因比较复杂,但是仔细观察,会发现克隆下来的 uuid 的版本是 v1.0.0-alpha.1 (在 "../uuid/Cargo.toml" 中可以查看),然后我们本地引入的 uuid 版本是 0.8.2,根据之前讲过的 crates.io 的版本规则,这两者是不兼容的,0.8.2 只能升级到 0.8.z,例如 0.8.3。
既然如此,我们先将 "../uuid/Cargo.toml" 中的 version = "1.0.0-alpha.1" 修改为 version = "0.8.3" ,然后看看结果先:
$ cargo build
Updating crates.io index
Compiling uuid v0.8.3 (/Users/sunfei/development/rust/demos/uuid)
大家注意到最后一行了吗?我们成功使用本地的 0.8.3 版本的 uuid 作为最新的依赖,因此也侧面证明了,补丁 patch 的版本也必须遵循相应的版本兼容规则!
如果修改后还是有问题,大家可以试试以下命令,指定版本进行更新:
% cargo update -p uuid --precise 0.8.3
Updating crates.io index
Updating uuid v0.8.3 (/Users/sunfei/development/rust/demos/uuid) -> v0.8.3
修复 bug 后,我们可以提交 pr 给 uuid,一旦 pr 被合并到了 master 分支,你可以直接通过以下方式来使用补丁:
[patch.crates-io]
uuid = { git = 'https://github.com/uuid-rs/uuid' }
等未来新的内容更新到 crates.io 后,大家就可以移除这个补丁,直接更新 [dependencies] 中的 uuid 版本即可!
使用未发布的小版本
还是 uuid 包,这次假设我们要为它新增一个特性,同时我们已经修改完毕,在本地测试过,并提交了相应的 pr,下面一起来看看该如何在它发布到 crates.io 之前继续使用。
再做一个假设,对于 uuid 来说,目前 crates.io 上的版本是 1.0.0,在我们提交了 pr 并合并到 master 分支后,master 上的版本变成了 1.0.1,这意味着未来 crates.io 上的版本也将变成 1.0.1。
为了使用新加的特性,同时当该包在未来发布到 crates.io 后,我们可以自动使用 crates.io 上的新版本,而无需再使用 patch 补丁,可以这样修改 Cargo.toml:
[package]
name = "my-library"
version = "0.1.0"
[dependencies]
uuid = "1.0.1"
[patch.crates-io]
uuid = { git = 'https://github.com/uuid-rs/uuid' }
注意,我们将 [dependencies] 中的 uuid 版本提前修改为 1.0.1,由于该版本在 crates.io 尚未发布,因此 patch 版本会被使用。
现在,我们的项目是基于 patch 版本的 uuid 来构建,也就是从 gihtub 的 master 分支中拉取最新的 commit 来构建。一旦未来 crates.io 上有了 1.0.1 版本,那项目就会继续基于 crates.io 来构建,此时,patch 就可以删除了。
间接使用 patch
现在假设项目 A 的依赖是 B 和 uuid,而 B 的依赖也是 uuid,此时我们可以让 A 和 B 都使用来自 GitHub 的 patch 版本,配置如下:
[package]
name = "my-binary"
version = "0.1.0"
[dependencies]
my-library = { git = 'https://example.com/git/my-library' }
uuid = "1.0.1"
[patch.crates-io]
uuid = { git = 'https://github.com/uuid-rs/uuid' }
如上所示,patch 不仅仅对于 my-binary 项目有用,对于 my-binary 的依赖 my-library 来说,一样可以间接生效。
非 crates.io 的 patch
若我们想要覆盖的依赖并不是来自 crates.io ,就需要对 [patch] 做一些修改。例如依赖是 git 仓库,然后使用本地路径来覆盖它:
[patch."https://github.com/your/repository"]
my-library = { path = "../my-library/path" }
easy,轻松搞定!
使用未发布的大版本
现在假设我们要发布一个大版本 2.0.0,与之前类似,可以将 Cargo.toml 修改如下:
[dependencies]
uuid = "2.0"
[patch.crates-io]
uuid = { git = "https://github.com/uuid-rs/uuid", branch = "2.0.0" }
此时 2.0 版本在 crates.io 上还不存在,因此我们使用了 patch 版本且指定了 branch = "2.0.0"。
间接使用 patch
这里需要注意,与之前的小版本不同,大版本的 patch 不会发生间接的传递!,例如:
[package]
name = "my-binary"
version = "0.1.0"
[dependencies]
my-library = { git = 'https://example.com/git/my-library' }
uuid = "1.0"
[patch.crates-io]
uuid = { git = 'https://github.com/uuid-rs/uuid', branch = '2.0.0' }
以上配置中, my-binary 将继续使用 1.x.y 系列的版本,而 my-library 将使用最新的 2.0.0 patch。
原因是,大版本更新往往会带来破坏性的功能,Rust 为了让我们平稳的升级,采用了滚动的方式:在依赖图中逐步推进更新,而不是一次性全部更新。
多版本[patch]
在之前章节,我们介绍过如何使用 package key 来重命名依赖包,现在来看看如何使用它同时引入多个 patch。
假设,我们对 serde 有两个新的 patch 需求:
serde官方解决了一个bug但是还没发布到crates.io,我们想直接从git仓库的最新commit拉取版本1.*- 我们自己为
serde添加了新的功能,命名为2.0.0版本,并将该版本上传到自己的git仓库中
为了满足这两个 patch,可以使用如下内容的 Cargo.toml:
[patch.crates-io]
serde = { git = 'https://github.com/serde-rs/serde' }
serde2 = { git = 'https://github.com/example/serde', package = 'serde', branch = 'v2' }
第一行说明,第一个 patch 从官方仓库 main 分支的最新 commit 拉取,而第二个则从我们自己的仓库拉取 v2 分支,同时将其重命名为 serde2。
这样,在代码中就可以分别通过 serde 和 serde2 引用不同版本的依赖库了。
通过[path]来覆盖依赖
有时我们只是临时性地对一个项目进行处理,因此并不想去修改它的 Cargo.toml。此时可以使用 Cargo 提供的路径覆盖方法: 注意,这个方法限制较多,如果可以,还是要使用 [patch]。
与 [patch] 修改 Cargo.toml 不同,路径覆盖修改的是 Cargo 自身的配置文件 $Home/.cargo/config.toml:
paths = ["/path/to/uuid"]
paths 数组中的元素是一个包含 Cargo.toml 的目录(依赖包),在当前例子中,由于我们只有一个 uuid,因此只需要覆盖它即可。目标路径可以是相对的,也是绝对的,需要注意,如果是相对路径,那是相对包含 .cargo 的 $Home 来说的。
不推荐的[replace]
[replace]已经被标记为deprecated,并将在未来被移除,请使用[patch]替代
虽然不建议使用,但是如果大家阅读其它项目时依然可能会碰到这种用法:
[replace]
"foo:0.1.0" = { git = 'https://github.com/example/foo' }
"bar:1.0.2" = { path = 'my/local/bar' }
语法看上去还是很清晰的,[replace] 中的每一个 key 都是 Package ID 格式,通过这种写法可以在依赖图中任意挑选一个节点进行覆盖。
Cargo.toml 格式讲解
Cargo.toml 又被称为清单( manifest ),文件格式是 TOML,每一个清单文件都由以下部分组成:
cargo-features— 只能用于nightly版本的feature[package]— 定义项目(package)的元信息name— 名称version— 版本authors— 开发作者edition— Rust edition.rust-version— 支持的最小化 Rust 版本description— 描述documentation— 文档 URLreadme— README 文件的路径homepage- 主页 URLrepository— 源代码仓库的 URLlicense— 开源协议 License.license-file— License 文件的路径.keywords— 项目的关键词categories— 项目分类workspace— 工作空间 workspace 的路径build— 构建脚本的路径links— 本地链接库的名称exclude— 发布时排除的文件include— 发布时包含的文件publish— 用于阻止项目的发布metadata— 额外的配置信息,用于提供给外部工具default-run— [cargo run] 所使用的默认可执行文件( binary )autobins— 禁止可执行文件的自动发现autoexamples— 禁止示例文件的自动发现autotests— 禁止测试文件的自动发现autobenches— 禁止 bench 文件的自动发现resolver— 设置依赖解析器( dependency resolver)
- Cargo Target 列表: (查看 Target 配置 获取详细设置)
[lib]— Library target 设置.[[bin]]— Binary target 设置.[[example]]— Example target 设置.[[test]]— Test target 设置.[[bench]]— Benchmark target 设置.
- Dependency tables:
[dependencies]— 项目依赖包[dev-dependencies]— 用于 examples、tests 和 benchmarks 的依赖包[build-dependencies]— 用于构建脚本的依赖包[target]— 平台特定的依赖包
[badges]— 用于在注册服务(例如 crates.io ) 上显示项目的一些状态信息,例如当前的维护状态:活跃中、寻找维护者、deprecated[features]—features可以用于条件编译[patch]— 推荐使用的依赖覆盖方式[replace]— 不推荐使用的依赖覆盖方式 (deprecated).[profile]— 编译器设置和优化[workspace]— 工作空间的定义
下面,我们将对其中一些部分进行详细讲解。
[package]
Cargo.toml 中第一个部分就是 package,用于设置项目的相关信息:
[package]
name = "hello_world" # the name of the package
version = "0.1.0" # the current version, obeying semver
authors = ["Alice <a@example.com>", "Bob <b@example.com>"]
其中,只有 name 和 version 字段是必须填写的。当发布到注册服务时,可能会有额外的字段要求,具体参见发布到 crates.io。
name
项目名用于引用一个项目( package ),它有几个用途:
- 其它项目引用我们的
package时,会使用该name - 编译出的可执行文件(bin target)的默认名称
name 只能使用 alphanumeric 字符、 - 和 _,并且不能为空。
事实上,name 的限制不止如此,例如:
- 当使用
cargo new或cargo init创建时,name还会被施加额外的限制,例如不能使用 Rust 关键字名称作为name - 如果要发布到
crates.io,那还有更多的限制:name使用ASCII码,不能使用已经被使用的名称,例如uuid已经在crates.io上被使用,因此我们只能使用类如uuid_v1的名称,才能将项目发布到crates.io上
version
Cargo 使用了语义化版本控制的概念,例如字符串 "0.1.12" 是一个 semver 格式的版本号,符合 "x.y.z" 的形式,其中 x 被称为主版本major, y 被称为小版本 minor ,而 z 被称为补丁 patch,可以看出从左到右,版本的影响范围逐步降低,补丁的更新是无关痛痒的,并不会造成 API 的兼容性被破坏。
使用该规则,你还需要遵循一些基本规则:
- 使用标准的
x.y.z形式的版本号,例如1.0.0而不是1.0 - 在版本到达
1.0.0之前,怎么都行,但是如果有破坏性变更( breaking changes ),需要增加minor版本号。例如,为结构体新增字段或为枚举新增成员就是一种破坏性变更 - 在
1.0.0之后,如果发生破坏性变更,需要增加major版本号 - 在
1.0.0之后不要去破坏构建流程 - 在
1.0.0之后,不要在patch更新中添加新的api(pub声明),如果要添加新的pub结构体、特征、类型、函数、方法等对象时,增加minor版本号
如果大家想知道 Rust 如何使用版本号来解析依赖,可以查看这里。同时 SemVer 兼容性 提供了更为详尽的破坏性变更列表。
authors
[package]
authors = ["Sunfei <contact@im.dev>"]
该字段仅用于项目的元信息描述和 build.rs 用到的 CARGO_PKG_AUTHORS 环境变量,它并不会显示在 crates.io 界面上。
警告:清单中的
[package]部分一旦发布到crates.io就无法进行更改,因此对于已发布的包来说,authors字段是无法修改的
edition
可选字段,用于指定项目所使用的 Rust Edition。
该配置将影响项目中的所有 Cargo Target 和包,前者包含测试用例、benchmark、可执行文件、示例等。
[package]
# ...
edition = '2021'
大多数时候,我们都无需手动指定,因为 cargo new 的时候,会自动帮我们添加。若 edition 配置不存在,那 2015 Edition 会被默认使用。
rust-version
可选字段,用于说明你的项目支持的最低 Rust 版本(编译器能顺利完成编译)。一旦你使用的 Rust 版本比这个字段设置的要低,Cargo 就会报错,然后告诉用户所需的最低版本。
该字段是在 Rust 1.56 引入的,若大家使用的 Rust 版本低于该版本,则该字段会被自动忽略时。
[package]
# ...
edition = '2021'
rust-version = "1.56"
还有一点,rust-version 必须比第一个引入 edition 的 Rust 版本要新。例如 Rust Edition 2021 是在 Rust 1.56 版本引入的,若你使用了 edition = '2021' 的 [package] 配置,则指定的 rust version 字段必须要要大于等于 1.56 版本。
还可以使用 --ignore-rust-version 命令行参数来忽略 rust-version。
该字段将影响项目中的所有 Cargo Target 和包,前者包含测试用例、benchmark、可执行文件、示例等。
description
该字段是项目的简介,crates.io 会在项目首页使用该字段包含的内容,不支持 Markdown 格式。
[package]
# ...
description = "A short description of my package"
注意: 若发布
crates.io,则该字段是必须的
documentation
该字段用于说明项目文档的地址,若没有设置,crates.io 会自动链接到 docs.rs 上的相应页面。
[package]
# ...
documentation = "https://docs.rs/bitflags"
readme
readme 字段指向项目的 README.md 文件,该文件应该存在项目的根目录下(跟 Cargo.toml 同级),用于向用户描述项目的详细信息,支持 Markdown 格式。大家看到的 crates.io 上的项目首页就是基于该文件的内容进行渲染的。
[package]
# ...
readme = "README.md"
若该字段未设置且项目根目录下存在 README.md、README.txt 或 README 文件,则该文件的名称将被默认使用。
你也可以通过将 readme 设置为 false 来禁止该功能,若设置为 true ,则默认值 README.md 将被使用。
homepage
该字段用于设置项目主页的 URL:
[package]
# ...
homepage = "https://serde.rs/"
repository
设置项目的源代码仓库地址,例如 GitHub 链接:
[package]
# ...
repository = "https://github.com/rust-lang/cargo/"
license 和 license-file
license 字段用于描述项目所遵循的开源协议。而 license-file 则用于指定包含开源协议的文件所在的路径(相对于 Cargo.toml)。
如果要发布到 crates.io ,则该协议必须是 SPDX2.1 协议表达式。同时 license 名称必须是来自于 SPDX 协议列表 3.11。
SPDX 只支持使用 AND 、OR 来组合多个开源协议:
[package]
# ...
license = "MIT OR Apache-2.0"
OR 代表用户可以任选一个协议进行遵循,而 AND 表示用户必须要同时遵循两个协议。还可以通过 WITH 来在指定协议之外添加额外的要求:
MIT OR Apache-2.0LGPL-2.1-only AND MIT AND BSD-2-ClauseGPL-2.0-or-later WITH Bison-exception-2.2
若项目使用了非标准的协议,你可以通过指定 license-file 字段来替代 license 的使用:
[package]
# ...
license-file = "LICENSE.txt"
注意:crates.io 要求必须设置
license或license-file
keywords
该字段使用字符串数组的方式来指定项目的关键字列表,当用户在 crates.io 上搜索时,这些关键字可以提供索引的功能。
[package]
# ...
keywords = ["gamedev", "graphics"]
注意:
crates.io最多只支持 5 个关键字,每个关键字都必须是合法的ASCII文本,且需要使用字母作为开头,只能包含字母、数字、_和-,最多支持 20 个字符长度
categories
categories 用于描述项目所属的类别:
categories = ["command-line-utilities", "development-tools::cargo-plugins"]
注意:
crates.io最多只支持 5 个类别,目前不支持用户随意自定义类别,你所使用的类别需要跟 https://crates.io/category_slugs 上的类别精准匹配。
workspace
该字段用于配置当前项目所属的工作空间。
若没有设置,则将沿着文件目录向上寻找,直至找到第一个 设置了 [workspace] 的Cargo.toml。因此,当一个成员不在工作空间的子目录时,设置该字段将非常有用。
[package]
# ...
workspace = "path/to/workspace/root"
需要注意的是 Cargo.toml 清单还有一个 [workspace] 部分专门用于设置工作空间,若它被设置了,则 package 中的 workspace 字段将无法被指定。这是因为一个包无法同时满足两个角色:
- 该包是工作空间的根包(root crate),通过
[workspace]指定) - 该包是另一个工作空间的成员,通过
package.workspace指定
若要了解工作空间的更多信息,请参见这里。
build
build 用于指定位于项目根目录中的构建脚本,关于构建脚本的更多信息,可以阅读 构建脚本 一章。
[package]
# ...
build = "build.rs"
还可以使用 build = false 来禁止构建脚本的自动检测。
links
用于指定项目链接的本地库的名称,更多的信息请看构建脚本章节的 links
[package]
# ...
links = "foo"
exclude 和 include
这两个字段可以用于显式地指定想要包含在外或在内的文件列表,往往用于发布到注册服务时。你可以使用 cargo package --list 来检查哪些文件被包含在项目中。
[package]
# ...
exclude = ["/ci", "images/", ".*"]
[package]
# ...
include = ["/src", "COPYRIGHT", "/examples", "!/examples/big_example"]
尽管大家可能没有指定 include 或 exclude,但是任然会有些规则自动被应用,一起来看看。
若 include 没有被指定,则以下文件将被排除在外:
- 项目不是 git 仓库,则所有以
.开头的隐藏文件会被排除 - 项目是 git 仓库,通过
.gitignore配置的文件会被排除
无论 include 或 exclude 是否被指定,以下文件都会被排除在外:
- 任何包含
Cargo.toml的子目录会被排除 - 根目录下的
target目录会被排除
以下文件会永远被 include ,你无需显式地指定:
Cargo.toml- 若项目包含可执行文件或示例代码,则最小化的
Cargo.lock会自动被包含 license-file指定的协议文件
这两个字段很强大,但是对于生产实践而言,我们还是推荐通过
.gitignore来控制,因为这样协作者更容易看懂。如果大家希望更深入的了解include/exclude,可以参考下官方的Cargo文档
publish
该字段常常用于防止项目因为失误被发布到 crates.io 等注册服务上,例如如果希望项目在公司内部私有化,你应该设置:
[package]
# ...
publish = false
也可以通过字符串数组的方式来指定允许发布到的注册服务名称:
[package]
# ...
publish = ["some-registry-name"]
若 publish 数组中包含了一个注册服务名称,则 cargo publish 命令会使用该注册服务,除非你通过 --registry 来设定额外的规则。
metadata
Cargo 默认情况下会对 Cargo.toml 中未使用的 key 进行警告,以帮助大家提前发现风险。但是 package.metadata 并不在其中,因为它是由用户自定义的提供给外部工具的配置文件。例如:
[package]
name = "..."
# ...
# 以下配置元数据可以在生成安卓 APK 时使用
[package.metadata.android]
package-name = "my-awesome-android-app"
assets = "path/to/static"
与其相似的还有 [workspace.metadata],都可以作为外部工具的配置信息来使用。
default-run
当大家使用 cargo run 来运行项目时,该命令会使用默认的二进制可执行文件作为程序启动入口。
我们可以通过 default-run 来修改默认的入口,例如现在有两个二进制文件 src/bin/a.rs 和 src/bin/b.rs,通过以下配置可以将入口设置为前者:
[package]
default-run = "a"
[badges]
该部分用于指定项目当前的状态,该状态会展示在 crates.io 的项目主页中,例如以下配置可以设置项目的维护状态:
[badges]
# `maintenance` 是项目的当前维护状态,它可能会被其它注册服务所使用,但是目前还没有被 `crates.io` 使用: https://github.com/rust-lang/crates.io/issues/2437
#
# `status` 字段时必须的,以下是可用的选项:
# - `actively-developed`: 新特性正在积极添加中,bug 在持续修复中
# - `passively-maintained`: 目前没有计划去支持新的特性,但是项目维护者可能会回答你提出的 issue
# - `as-is`: 该项目的功能已经完结,维护者不准备继续开发和提供支持了,但是它的功能已经达到了预期
# - `experimental`: 作者希望同大家分享,但是还不准备满足任何人的特殊要求
# - `looking-for-maintainer`: 当前维护者希望将项目转移给新的维护者
# - `deprecated`: 不再推荐使用该项目,需要说明原因以及推荐的替代项目
# - `none`: 不显示任何 badge ,因此维护者没有说明他们的状态,用户需要自己去调查发生了什么
maintenance = { status = "..." }
[dependencies]
在之前章节中,我们已经详细介绍过 [dependencies] 、 [dev-dependencies] 和 [build-dependencies],这里就不再赘述。
[profile.*]
该部分可以对编译器进行配置,例如 debug 和优化,在后续的编译器优化章节有详细介绍。
Cargo Target
Cargo 项目中包含有一些对象,它们包含的源代码文件可以被编译成相应的包,这些对象被称之为 Cargo Target。例如之前章节提到的库对象 Library 、二进制对象 Binary、示例对象 Examples、测试对象 Tests 和 基准性能对象 Benches 都是 Cargo Target。
本章节我们一起来看看该如何在 Cargo.toml 清单中配置这些对象,当然,大部分时候都无需手动配置,因为默认的配置通常由项目目录的布局自动推断出来。
对象介绍
在开始讲解如何配置对象前,我们先来看看这些对象究竟是什么,估计还有些同学对此有些迷糊 :)
库对象(Library)
库对象用于定义一个库,该库可以被其它的库或者可执行文件所链接。该对象包含的默认文件名是 src/lib.rs,且默认情况下,库对象的名称跟项目名是一致的,
一个工程只能有一个库对象,因此也只能有一个 src/lib.rs 文件,以下是一种自定义配置:
# 一个简单的例子:在 Cargo.toml 中定制化库对象
[lib]
crate-type = ["cdylib"]
bench = false
二进制对象(Binaries)
二进制对象在被编译后可以生成可执行的文件,默认的文件名是 src/main.rs,二进制对象的名称跟项目名也是相同的。
大家应该还记得,一个项目拥有多个二进制文件,因此一个项目可以拥有多个二进制对象。当拥有多个对象时,对象的文件默认会被放在 src/bin/ 目录下。
二进制对象可以使用库对象提供的公共 API,也可以通过 [dependencies] 来引入外部的依赖库。
我们可以使用 cargo run --bin <bin-name> 的方式来运行指定的二进制对象,以下是二进制对象的配置示例:
# Example of customizing binaries in Cargo.toml.
[[bin]]
name = "cool-tool"
test = false
bench = false
[[bin]]
name = "frobnicator"
required-features = ["frobnicate"]
示例对象(Examples)
示例对象的文件在根目录下的 examples 目录中。既然是示例,自然是使用项目中的库对象的功能进行演示。示例对象编译后的文件会存储在 target/debug/examples 目录下。
如上所示,示例对象可以使用库对象的公共 API,也可以通过 [dependencies] 来引入外部的依赖库。
默认情况下,示例对象都是可执行的二进制文件( 带有 fn main() 函数入口),毕竟例子是用来测试和演示我们的库对象,是用来运行的。而你完全可以将示例对象改成库的类型:
[[example]]
name = "foo"
crate-type = ["staticlib"]
如果想要指定运行某个示例对象,可以使用 cargo run --example <example-name> 命令。如果是库类型的示例对象,则可以使用 cargo build --example <example-name> 进行构建。
与此类似,还可以使用 cargo install --example <example-name> 来将示例对象编译出的可执行文件安装到默认的目录中,将该目录添加到 $PATH 环境变量中,就可以直接全局运行安装的可执行文件。
最后,cargo test 命令默认会对示例对象进行编译,以防止示例代码因为长久没运行,导致严重过期以至于无法运行。
测试对象(Tests)
测试对象的文件位于根目录下的 tests 目录中,如果大家还有印象的话,就知道该目录是集成测试所使用的。
当运行 cargo test 时,里面的每个文件都会被编译成独立的包,然后被执行。
测试对象可以使用库对象提供的公共 API,也可以通过 [dependencies] 来引入外部的依赖库。
基准性能对象(Benches)
该对象的文件位于 benches 目录下,可以通过 cargo bench 命令来运行,关于基准测试,可以通过这篇文章了解更多。
配置一个对象
我们可以通过 Cargo.toml 中的 [lib]、[[bin]]、[[example]]、[[test]] 和 [[bench]] 部分对以上对象进行配置。
大家可能会疑惑
[lib]和[[bin]]的写法为何不一致,原因是这种语法是TOML提供的数组特性,[[bin]]这种写法意味着我们可以在 Cargo.toml 中创建多个[[bin]],每一个对应一个二进制文件上文提到过,我们只能指定一个库对象,因此这里只能使用
[lib]形式
由于它们的配置内容都是相似的,因此我们以 [lib] 为例来说明相应的配置项:
[lib]
name = "foo" # 对象名称: 库对象、`src/main.rs` 二进制对象的名称默认是项目名
path = "src/lib.rs" # 对象的源文件路径
test = true # 能否被测试,默认是 true
doctest = true # 文档测试是否开启,默认是 true
bench = true # 基准测试是否开启
doc = true # 文档功能是否开启
plugin = false # 是否可以用于编译器插件(deprecated).
proc-macro = false # 是否是过程宏类型的库
harness = true # 是否使用libtest harness : https://doc.rust-lang.org/stable/rustc/tests/index.html
edition = "2015" # 对象使用的 Rust Edition
crate-type = ["lib"] # 生成的包类型
required-features = [] # 构建对象所需的 Cargo Features (N/A for lib).
name
对于库对象和默认的二进制对象( src/main.rs ),默认的名称是项目的名称( package.name )。
对于其它类型的对象,默认是目录或文件名。
除了 [lib] 外,name 字段对于其他对象都是必须的。
proc-macro
该字段的使用方式在过程宏章节有详细的介绍。
edition
对使用的 Rust Edition 版本进行设置。
如果没有设置,则默认使用 [package] 中配置的 package.edition,通常来说,这个字段不应该被单独设置,只有在一些特殊场景中才可能用到:例如将一个大型项目逐步升级为新的 edition 版本。
crate-type
该字段定义了对象生成的包类型。它是一个数组,因此为同一个对象指定多个包类型。
需要注意的是,只有库对象和示例对象可以被指定,因为其他的二进制、测试和基准测试对象只能是 bin 这个包类型。
默认的包类型如下:
| 对象 | 包类型 |
|---|---|
| 正常的库对象 | "lib" |
| 过程宏的库对象 | "proc-macro" |
| 示例对象 | "bin" |
可用的选项包括 bin、lib、rlib、dylib、cdylib、staticlib 和 proc-macro ,如果大家想了解更多,可以看下官方的参考手册。
required-features
该字段用于指定在构建对象时所需的 features 列表。
该字段只对 [[bin]]、 [[bench]]、 [[test]] 和 [[example]] 有效,对于 [lib] 没有任何效果。
[features]
# ...
postgres = []
sqlite = []
tools = []
[[bin]]
name = "my-pg-tool"
required-features = ["postgres", "tools"]
对象自动发现
默认情况下,Cargo 会基于项目的目录文件布局自动发现和确定对象,而之前的配置项则允许我们对其进行手动的配置修改(若项目布局跟标准的不一样时)。
而这种自动发现对象的设定可以通过以下配置来禁用:
[package]
# ...
autobins = false
autoexamples = false
autotests = false
autobenches = false
只有在特定场景下才应该禁用自动对象发现。例如,你有一个模块想要命名为 bin,目录结构如下:
├── Cargo.toml
└── src
├── lib.rs
└── bin
└── mod.rs
这在默认情况下会导致问题,因为 Cargo 会使用 src/bin 作为存放二进制对象的地方。
为了阻止这一点,可以设置 autobins = false :
├── Cargo.toml
└── src
├── lib.rs
└── bin
└── mod.rs
工作空间 Workspace
一个工作空间是由多个 package 组成的集合,它们共享同一个 Cargo.lock 文件、输出目录和一些设置(例如 profiles : 编译器设置和优化)。组成工作空间的 packages 被称之为工作空间的成员。
工作空间的两种类型
工作空间有两种类型:root package 和虚拟清单( virtual manifest )。
根 package
若一个 package 的 Cargo.toml 包含了[package] 的同时又包含了 [workspace] 部分,则该 package 被称为工作空间的根 package。
换而言之,一个工作空间的根( root )是该工作空间的 Cargo.toml 文件所在的目录。
举个例子,我们现在有多个 package,它们的目录是嵌套关系,然后我们在最外层的 package,也就是最外层目录中的 Cargo.toml 中定义一个 [workspace],此时这个最外层的 package 就是工作空间的根。
再举个例子,大名鼎鼎的 ripgrep 就在最外层的 package 中定义了 [workspace] :
[workspace]
members = [
"crates/globset",
"crates/grep",
"crates/cli",
"crates/matcher",
"crates/pcre2",
"crates/printer",
"crates/regex",
"crates/searcher",
"crates/ignore",
]
那么最外层的目录就是 ripgrep 的工作空间的根。
虚拟清单
若一个 Cargo.toml 有 [workspace] 但是没有 [package] 部分,则它是虚拟清单类型的工作空间。
对于没有主 package 的场景或你希望将所有的 package 组织在单独的目录中时,这种方式就非常适合。
例如 rust-analyzer 就是这样的项目,它的根目录中的 Cargo.toml 中并没有 [package],说明该根目录不是一个 package,但是却有 [workspace] :
[workspace]
members = ["xtask/", "lib/*", "crates/*"]
exclude = ["crates/proc_macro_test/imp"]
结合 rust-analyzer 的目录布局可以看出,该工作空间的所有成员 package 都在单独的目录中,因此这种方式很适合虚拟清单的工作空间。
关键特性
工作空间的几个关键点在于:
- 所有的
package共享同一个Cargo.lock文件,该文件位于工作空间的根目录中 - 所有的
package共享同一个输出目录,该目录默认的名称是target,位于工作空间根目录下 - 只有工作空间根目录的
Cargo.toml才能包含[patch],[replace]和[profile.*],而成员的Cargo.toml中的相应部分将被自动忽略
[workspace]
Cargo.toml 中的 [workspace] 部分用于定义哪些 packages 属于工作空间的成员:
[workspace]
members = ["member1", "path/to/member2", "crates/*"]
exclude = ["crates/foo", "path/to/other"]
若某个本地依赖包是通过 path 引入,且该包位于工作空间的目录中,则该包自动成为工作空间的成员。
剩余的成员需要通过 workspace.members 来指定,里面包含了各个成员所在的目录(成员目录中包含了 Cargo.toml )。
members 还支持使用 glob 来匹配多个路径,例如上面的例子中使用 crates/* 匹配 crates 目录下的所有包。
exclude 可以将指定的目录排除在工作空间之外,例如还是上面的例子,crates/* 在包含了 crates 目录下的所有包后,又通过 exclude 中 crates/foo 将 crates 下的 foo 目录排除在外。
你也可以将一个空的 [workspace] 直接联合 [package] 使用,例如:
[package]
name = "hello"
version = "0.1.0"
[workspace]
此时的工作空间的成员包含:
- 根
package: "hello" - 所有通过
path引入的本地依赖(位于工作空间目录下)
选择工作空间
选择工作空间有两种方式:Cargo 自动查找、手动指定 package.workspace 字段。
当位于工作空间的子目录中时,Cargo 会自动在该目录的父目录中寻找带有 [workspace] 定义的 Cargo.toml,然后再决定使用哪个工作空间。
我们还可以使用下面的方法来覆盖 Cargo 自动查找功能:将成员包中的 package.workspace 字段修改为工作区间根目录的位置,这样就能显式地让一个成员使用指定的工作空间。
当成员不在工作空间的子目录下时,这种手动选择工作空间的方法就非常适用。毕竟 Cargo 的自动搜索是沿着父目录往上查找,而成员并不在工作空间的子目录下,这意味着顺着成员的父目录往上找是无法找到该工作空间的 Cargo.toml 的,此时就只能手动指定了。
选择 package
在工作空间中,package 相关的 Cargo 命令(例如 cargo build )可以使用 -p 、 --package 或 --workspace 命令行参数来指定想要操作的 package。
若没有指定任何参数,则 Cargo 将使用当前工作目录的中的 package 。若工作目录是虚拟清单类型的工作空间,则该命令将作用在所有成员上(就好像是使用了 --workspace 命令行参数)。而 default-members 可以在命令行参数没有被提供时,手动指定操作的成员:
[workspace]
members = ["path/to/member1", "path/to/member2", "path/to/member3/*"]
default-members = ["path/to/member2", "path/to/member3/foo"]
这样一来, cargo build 就不会应用到虚拟清单工作空间的所有成员,而是指定的成员上。
workspace.metadata
与 package.metadata 非常类似,workspace.metadata 会被 Cargo 自动忽略,就算没有被使用也不会发出警告。
这个部分可以用于让工具在 Cargo.toml 中存储一些工作空间的配置元信息。例如:
[workspace]
members = ["member1", "member2"]
[workspace.metadata.webcontents]
root = "path/to/webproject"
tool = ["npm", "run", "build"]
# ...
条件编译 Features
Cargo Feature 是非常强大的机制,可以为大家提供条件编译和可选依赖的高级特性。
[features]
Feature 可以通过 Cargo.toml 中的 [features] 部分来定义:其中每个 feature 通过列表的方式指定了它所能启用的其他 feature 或可选依赖。
假设我们有一个 2D 图像处理库,然后该库所支持的图片格式可以通过以下方式启用:
[features]
# 定义一个 feature : webp, 但它并没有启用其它 feature
webp = []
当定义了 webp 后,我们就可以在代码中通过 cfg 表达式来进行条件编译。例如项目中的 lib.rs 可以使用以下代码对 webp 模块进行条件引入:
#[cfg(feature = "webp")]
pub mod webp;
#[cfg(feature = "webp")] 的含义是:只有在 webp feature 被定义后,以下的 webp 模块才能被引入进来。由于我们之前在 [features] 里定义了 webp,因此以上代码的 webp 模块会被成功引入。
在 Cargo.toml 中定义的 feature 会被 Cargo 通过命令行参数 --cfg 传给 rustc,最终由后者完成编译:rustc --cfg ...。若项目中的代码想要测试 feature 是否存在,可以使用 cfg 属性或 cfg 宏。
之前我们提到了一个 feature 还可以开启其他 feature,举个例子,例如 ICO 图片格式包含 BMP 和 PNG 格式,因此当 ico 被启用后,它还得确保启用 bmp 和 png :
[features]
bmp = []
png = []
ico = ["bmp", "png"]
webp = []
对此,我们可以理解为: bmp 和 png 是开启 ico 的先决条件(注:开启 ico,会自动开启 bmp, png)。
Feature 名称可以包含来自 Unicode XID standard 定义的字母,允许使用 _ 或 0-9 的数字作为起始字符,在起始字符后,还可以使用 -、+ 或 . 。
但是我们还是推荐按照 crates.io 的方式来设置 Feature 名称 : crate.io 要求名称只能由 ASCII 字母数字、_、- 或 + 组成。
default feature
默认情况下,所有的 feature 都会被自动禁用,可以通过 default 来启用它们:
[features]
default = ["ico", "webp"]
bmp = []
png = []
ico = ["bmp", "png"]
webp = []
使用如上配置的项目被构建时,default feature 首先会被启用,然后它接着启用了 ico 和 webp feature,当然我们还可以关闭 default:
--no-default-features命令行参数可以禁用defaultfeaturedefault-features = false选项可以在依赖声明中指定
当你要去改变某个依赖库的
default启用的 feature 列表时(例如觉得该库引入的 feature 过多,导致最终编译出的文件过大),需要格外的小心,因为这可能会导致某些功能的缺失
可选依赖
当依赖被标记为 "可选 optional" 时,意味着它默认不会被编译。假设我们的 2D 图片处理库需要用到一个外部的包来处理 GIF 图片:
[dependencies]
gif = { version = "0.11.1", optional = true }
这种可选依赖的写法会自动定义一个与依赖同名的 feature,也就是 gif feature,这样一来,当我们启用 gif feature 时,该依赖库也会被自动引入并启用:例如通过 --feature gif 的方式启用 feature 。
注意:目前来说,
[feature]中定义的 feature 还不能与已引入的依赖库同名。但是在nightly中已经提供了实验性的功能用于改变这一点: namespaced features
当然,我们还可以通过显式定义 feature 的方式来启用这些可选依赖库,例如为了支持 AVIF 图片格式,我们需要引入两个依赖包,由于 avif 是通过 feature 引入的可选格式,因此它依赖的两个包也必须声明为可选的:
[dependencies]
ravif = { version = "0.6.3", optional = true }
rgb = { version = "0.8.25", optional = true }
[features]
avif = ["ravif", "rgb"]
之后,avif feature 一旦被启用,那这两个依赖库也将自动被引入。
注意:我们之前也讲过条件引入依赖的方法,那就是使用平台相关的依赖,与基于 feature 的可选依赖不同,它们是基于特定平台的可选依赖
依赖库自身的 feature
就像我们的项目可以定义 feature 一样,依赖库也可以定义它自己的 feature,也有需要启用的 feature 列表,当引入该依赖库时,我们可以通过以下方式为其启用相关的 features :
[dependencies]
serde = { version = "1.0.118", features = ["derive"] }
以上配置为 serde 依赖开启了 derive feature,还可以通过 default-features = false 来禁用依赖库的 default feature :
[dependencies]
flate2 = { version = "1.0.3", default-features = false, features = ["zlib"] }
这里我们禁用了 flate2 的 default feature,但又手动为它启用了 zlib feature。
注意:这种方式未必能成功禁用
default,原因是可能会有其它依赖也引入了flate2,并且没有对default进行禁用,那此时default依然会被启用。查看下文的 feature 同一化 获取更多信息
除此之外,还能通过下面的方式来间接开启依赖库的 feature :
[dependencies]
jpeg-decoder = { version = "0.1.20", default-features = false }
[features]
# Enables parallel processing support by enabling the "rayon" feature of jpeg-decoder.
parallel = ["jpeg-decoder/rayon"]
如上所示,我们定义了一个 parallel feature,同时为其启用了 jpeg-decoder 依赖的 rayon feature。
注意: 上面的 "package-name/feature-name" 语法形式不仅会开启指定依赖的指定 feature,若该依赖是可选依赖,那还会自动将其引入
在
nightly版本中,可以对这种行为进行禁用:weak dependency features
通过命令行参数启用 feature
以下的命令行参数可以启用指定的 feature :
--features FEATURES: 启用给出的 feature 列表,可以使用逗号或空格进行分隔,若你是在终端中使用,还需要加上双引号,例如--features "foo bar"。 若在工作空间中构建多个package,可以使用package-name/feature-name为特定的成员启用 features--all-features: 启用命令行上所选择的所有包的所有 features--no-default-features: 对选择的包禁用defaultfeature
feature 同一化
feature 只有在定义的包中才是唯一的,不同包之间的 feature 允许同名。因此,在一个包上启用 feature 不会导致另一个包的同名 feature 被误启用。
当一个依赖被多个包所使用时,这些包对该依赖所设置的 feature 将被进行合并,这样才能确保该依赖只有一个拷贝存在,这个过程就被称之为同一化。大家可以查看这里了解下解析器如何对 feature 进行解析处理。
这里,我们使用 winapi 为例来说明这个过程。首先,winapi 使用了大量的 features;然后我们有两个包 foo 和 bar 分别使用了它的两个 features,那么在合并后,最终 winapi 将同时启四个 features :
由于这种不可控性,我们需要让 启用feature = 添加特性 这个等式成立,换而言之,启用一个 feature 不应该导致某个功能被禁止。这样才能的让多个包启用同一个依赖的不同 features。
例如,如果我们想可选的支持 no_std 环境(不使用标准库),那么有两种做法:
- 默认代码使用标准库的,当
no_stdfeature 启用时,禁用相关的标准库代码 - 默认代码使用非标准库的,当
stdfeature 启用时,才使用标准库的代码
前者就是功能削减,与之相对,后者是功能添加,根据之前的内容,我们应该选择后者的做法:
#![allow(unused)] #![no_std] fn main() { #[cfg(feature = "std")] extern crate std; #[cfg(feature = "std")] pub fn function_that_requires_std() { // ... } }
彼此互斥的 feature
某极少数情况下,features 之间可能会互相不兼容。我们应该避免这种设计,因为如果一旦这么设计了,那你可能需要修改依赖图的很多地方才能避免两个不兼容 feature 的同时启用。
如果实在没有办法,可以考虑增加一个编译错误来让报错更清晰:
#[cfg(all(feature = "foo", feature = "bar"))]
compile_error!("feature \"foo\" and feature \"bar\" cannot be enabled at the same time");
当同时启用 foo 和 bar 时,编译器就会爆出一个更清晰的错误:feature foo 和 bar 无法同时启用。
总之,我们还是应该在设计上避免这种情况的发生,例如:
- 将某个功能分割到多个包中
- 当冲突时,设置 feature 优先级,cfg-if 包可以帮助我们写出更复杂的
cfg表达式
检视已解析的 features
在复杂的依赖图中,如果想要了解不同的 features 是如何被多个包多启用的,这是相当困难的。好在 cargo tree 命令提供了几个选项可以帮组我们更好的检视哪些 features 被启用了:
cargo tree -e features ,该命令以依赖图的方式来展示已启用的 features,包含了每个依赖包所启用的特性:
$ cargo tree -e features
test_cargo v0.1.0 (/Users/sunfei/development/rust/demos/test_cargo)
└── uuid feature "default"
├── uuid v0.8.2
└── uuid feature "std"
└── uuid v0.8.2
cargo tree -f "{p} {f}" 命令会提供一个更加紧凑的视图:
$ cargo tree -f "{p} {f}"
test_cargo v0.1.0 (/Users/sunfei/development/rust/demos/test_cargo)
└── uuid v0.8.2 default,std
cargo tree -e features -i foo,该命令会显示 features 会如何"流入"指定的包 foo 中:
$ cargo tree -e features -i uuid
uuid v0.8.2
├── uuid feature "default"
│ └── test_cargo v0.1.0 (/Users/sunfei/development/rust/demos/test_cargo)
│ └── test_cargo feature "default" (command-line)
└── uuid feature "std"
└── uuid feature "default" (*)
该命令在依赖图较为复杂时非常有用,使用它可以让你了解某个依赖包上开启了哪些 features 以及其中的原因。
大家可以查看官方的 cargo tree 文档获取更加详细的使用信息。
Feature 解析器 V2 版本
我们还能通过以下配置指定使用 V2 版本的解析器( resolver ):
[package]
name = "my-package"
version = "1.0.0"
resolver = "2"
V2 版本的解析器可以在某些情况下避免 feature 同一化的发生,具体的情况在这里有描述,下面做下简单的总结:
- 为特定平台开启的
features且此时并没有被构建,会被忽略 build-dependencies和proc-macros不再跟普通的依赖共享featuresdev-dependencies的features不会被启用,除非正在构建的对象需要它们(例如测试对象、示例对象等)
对于部分场景而言,feature 同一化确实是需要避免的,例如,一个构建依赖开启了 std feature,而同一个依赖又被用于 no_std 环境,很明显,开启 std 将导致错误的发生。
说完优点,我们再来看看 V2 的缺点,其中增加编译构建时间就是其中之一,原因是同一个依赖会被构建多次(每个都拥有不同的 feature 列表)。
由于此部分内容可能只有极少数的用户需要,因此我们并没有对其进行扩展,如果大家希望了解更多关于 V2 的内容,可以查看官方文档
构建脚本
构建脚本可以通过 CARGO_FEATURE_<name> 环境变量获取启用的 feature 列表,其中 <name> 是 feature 的名称,该名称被转换成大全写字母,且 - 被转换为 _。
required-features
该字段可以用于禁用特定的 Cargo Target:当某个 feature 没有被启用时,查看这里获取更多信息。
SemVer 兼容性
启用一个 feature 不应该引入一个不兼容 SemVer 的改变。例如,启用的 feature 不应该改变现有的 API,因为这会给用户造成不兼容的破坏性变更。 如果大家想知道哪些变化是兼容的,可以参见官方文档。
总之,在新增/移除 feature 或可选依赖时,你需要小心,因此这些可能会造成向后不兼容性。更多信息参见这里,简单总结如下:
- 在发布
minor版本时,以下通常是安全的: - 在发布
minor版本时,以下操作应该避免:
feature 文档和发现
将你的项目支持的 feature 信息写入到文档中是非常好的选择:
- 我们可以通过在
lib.rs的顶部添加文档注释的方式来实现。例如regex就是这么做的。 - 若项目拥有一个用户手册,那也可以在那里添加说明,例如 serde.rs。
- 若项目是二进制类型(可运行的应用服务,包含
fn main入口),可以将说明放在README文件或其他文档中,例如 sccache。
特别是对于不稳定的或者不该再被使用的 feature 而言,它们更应该被放在文档中进行清晰的说明。
当构建发布到 docs.rs 上的文档时,会使用 Cargo.toml 中的元数据来控制哪些 features 会被启用。查看 docs.rs 文档获取更多信息。
如何发现 features
若依赖库的文档中对其使用的 features 做了详细描述,那你会更容易知道他们使用了哪些 features 以及该如何使用。
当依赖库的文档没有相关信息时,你也可以通过源码仓库的 Cargo.toml 文件来获取,但是有些时候,使用这种方式来跟踪并获取全部相关的信息是相当困难的。
Features 示例
以下我们一起来看看一些来自真实世界的示例。
最小化构建时间和文件大小
如果一些包的部分特性不再启用,就可以减少该包占用的大小以及编译时间:
syn包可以用来解析 Rust 代码,由于它很受欢迎,大量的项目都在引用,因此它给出了非常清晰的文档关于如何最小化使用它包含的featuresregex也有关于 features 的描述文档,例如移除 Unicode 支持的 feature 可以降低最终生成可执行文件的大小winapi拥有众多 features,这些feature对用了各种 Windows API,你可以只引入代码中用到的 API 所对应的 feature.
行为扩展
serde_json 拥有一个 preserve_order feature,可以用于在序列化时保留 JSON 键值对的顺序。同时,该 feature 还会启用一个可选依赖 indexmap。
当这么做时,一定要小心不要破坏了 SemVer 的版本兼容性,也就是说:启用 feature 后,代码依然要能正常工作。
no_std 支持
一些包希望能同时支持 no_std 和 std 环境,例如该包希望支持嵌入式系统或资源紧张的系统,且又希望能支持其它的平台,此时这种做法是非常有用的,因为标准库 std 会大幅增加编译出来的文件的大小,对于资源紧张的系统来说,no_std 才是最合适的。
wasm-bindgen 定义了一个 std feature,它是默认启用的。首先,在库的顶部,它无条件的启用了 no_std 属性,它可以确保 std 和 std prelude 不会自动引入到作用域中来。其次,在不同的地方(示例 1,示例 2),它通过 #[cfg(feature = "std")] 启用 std feature 来添加 std 标准库支持。
对依赖库的 features 进行再导出
从依赖库再导出 features 在有些场景中会相当有用,这样用户就可以通过依赖包的 features 来控制功能而不是自己去手动定义。
例如 regex 将 regex_syntax 包的 features 进行了再导出,这样 regex 的用户无需知道 regex_syntax 包,但是依然可以访问后者包含的 features。
feature 优先级
一些包可能会拥有彼此互斥的 features(无法共存,上一章节中有讲到),其中一个办法就是为 feature 定义优先级,这样其中一个就会优于另一个被启用。
例如 log 包,它有几个 features 可以用于在编译期选择最大的日志级别,这里,它就使用了 cfg-if 的方式来设置优先级。一旦多个 features 被启用,那更高优先级的就会优先被启用。
过程宏包
一些包拥有过程宏,这些宏必须定义在一个独立的包中。但是不是所有的用户都需要过程宏的,因此也无需引入该包。
在这种情况下,将过程宏所在的包定义为可选依赖,是很不错的选择。这样做还有一个好处:有时过程宏的版本必须要跟父包进行同步,但是我们又不希望所有的用户都进行同步。
其中一个例子就是 serde ,它有一个 derive feature 可以启用 serde_derive 过程宏。由于 serde_derive 包跟 serde 的关系非常紧密,因此它使用了版本相同的需求来保证两者的版本同步性。
只能用于 nightly 的 feature
Rust 有些实验性的 API 或语言特性只能在 nightly 版本下使用,但某些使用了这些 API 的包并不想强制他们的用户也使用 nightly 版本,因此他们会通过 feature 的方式来控制。
若用户希望使用这些 API 时,需要启用相应的 feature ,而这些 feature 只能在 nightly 下使用。若用户不需要使用这些 API,就无需开启 相应的 feature,自然也不需要使用 nightly 版本。
例如 rand 包有一个 simd_support feature 就只能在 nightly 下使用,若我们不使用该 feature,则在 stable 下依然可以使用 rand。
实验性 feature
有一些包会提前将一些实验性的 API 放出去,既然是实验性的,自然无法保证其稳定性。在这种情况下,通常会在文档中将相应的 features 标记为实验性,意味着它们在未来可能会发生大的改变(甚至 minor 版本都可能发生)。
其中一个例子是 async-std 包,它拥有一个 unstable feature,用来标记一些新的 API,表示人们已经可以选择性的使用但是还没有准备好去依赖它。
发布配置 Profile
细心的同学可能发现了迄今为止我们已经为 Cargo 引入了不少新的名词,而且这些名词有一个共同的特点,不容易或不适合翻译成中文,因为难以表达的很准确,例如 Cargo Target, Feature 等,这不现在又多了一个 Profile。
默认的 profile
Profile 其实是一种发布配置,例如它默认包含四种: dev、 release、 test 和 bench,正常情况下,我们无需去指定,Cargo 会根据我们使用的命令来自动进行选择
- 例如
cargo build自动选择devprofile,而cargo test则是testprofile, 出于历史原因,这两个 profile 输出的结果都存放在项目根目录下的target/debug目录中,结果往往用于开发/测试环境 - 而
cargo build --release自动选择releaseprofile,并将输出结果存放在target/release目录中,结果往往用于生产环境
可以看出 Profile 跟 Nodejs 的 dev 和 prod 很像,都是通过不同的配置来为目标环境构建最终编译后的结果: dev 编译输出的结果用于开发环境,prod 则用于生产环境。
针对不同的 profile,编译器还会提供不同的优化级别,例如 dev 用于开发环境,因此构建速度是最重要的:此时,我们可以牺牲运行性能来换取编译性能,那么优化级别就会使用最低的。而 release 则相反,优化级别会使用最高,导致的结果就是运行得非常快,但是编译速度大幅降低。
初学者一个常见的错误,就是使用非
releaseprofile 去测试性能,例如cargo run,这种方式显然无法得到正确的结果,我们应该使用cargo run --release的方式测试性能
profile 可以通过 Cargo.toml 中的 [profile] 部分进行设置和改变:
[profile.dev]
opt-level = 1 # 使用稍高一些的优化级别,最低是0,最高是3
overflow-checks = false # 关闭整数溢出检查
需要注意的是,每一种 profile 都可以单独的进行设置,例如上面的 [profile.dev]。
如果是工作空间的话,只有根 package 的 Cargo.toml 中的 [profile] 设置才会被使用,其它成员或依赖包中的设置会被自动忽略。
另外,profile 还能在 Cargo 自身的配置文件中进行覆盖,总之,通过 .cargo/config.toml 或环境变量的方式所指定的 profile 配置会覆盖项目的 Cargo.toml 中相应的配置。
自定义 profile
除了默认的四种 profile,我们还可以定义自己的。对于大公司来说,这个可能会非常有用,自定义的 profile 可以帮助我们建立更灵活的工作发布流和构建模型。
当定义 profile 时,你必须指定 inherits 用于说明当配置缺失时,该 profile 要从哪个 profile 那里继承配置。
例如,我们想在 release profile 的基础上增加 LTO 优化,那么可以在 Cargo.toml 中添加如下内容:
[profile.release-lto]
inherits = "release"
lto = true
然后在构建时使用 --profile 来指定想要选择的自定义 profile :
$ cargo build --profile release-lto
与默认的 profile 相同,自定义 profile 的编译结果也存放在 target/ 下的同名目录中,例如 --profile release-lto 的输出结果存储在 target/release-lto 中。
选择 profile
- 默认使用
dev:cargo build,cargo rustc,cargo check, 和cargo run - 默认使用
test:cargo test - 默认使用
bench:cargo bench - 默认使用
release:cargo install,cargo build --release,cargo run --release - 使用自定义 profile:
cargo build --profile release-lto
profile 设置
下面我们来看看 profile 中可以进行哪些优化设置。
opt-level
该字段用于控制 -C opt-level 标志的优化级别。更高的优化级别往往意味着运行更快的代码,但是也意味着更慢的编译速度。
同时,更高的编译级别甚至会造成编译代码的改变和再排列,这会为 debug 带来更高的复杂度。
opt-level 支持的选项包括:
0: 无优化1: 基本优化2: 一些优化3: 全部优化- "s": 优化输出的二进制文件的大小
- "z": 优化二进制文件大小,但也会关闭循环向量化
我们非常推荐你根据自己的需求来找到最适合的优化级别(例如,平衡运行和编译速度)。而且有一点值得注意,有的时候优化级别和性能的关系可能会出乎你的意料之外,例如 3 比 2 更慢,再比如 "s" 并没有让你的二进制文件变得更小。
而且随着 rustc 版本的更新,你之前的配置也可能要随之变化,总之,为项目的热点路径做好基准性能测试是不错的选择,不然总不能每次都手动重写代码来测试吧 :)
如果想要了解更多,可以参考 rustc 文档,这里有更高级的优化技巧。
debug
debug 控制 -C debuginfo 标志,而后者用于控制最终二进制文件输出的 debug 信息量。
支持的选项包括:
0或false:不输出任何 debug 信息1: 行信息2: 完整的 debug 信息
split-debuginfo
split-debuginfo 控制 -C split-debuginfo 标志,用于决定输出的 debug 信息是存放在二进制可执行文件里还是邻近的文件中。
debug-assertions
该字段控制 -C debug-assertions 标志,可以开启或关闭其中一个条件编译选项: cfg(debug_assertions)。
debug-assertion 会提供运行时的检查,该检查只能用于 debug 模式,原因是对于 release 来说,这种检查的成本较为高昂。
大家熟悉的 debug_assert! 宏也是通过该标志开启的。
支持的选项包括 :
true: 开启false: 关闭
overflow-checks
用于控制 -C overflow-checks 标志,该标志可以控制运行时的整数溢出行为。当开启后,整数溢出会导致 panic。
支持的选项包括 :
true: 开启false: 关闭
lto
lto 用于控制 -C lto 标志,而后者可以控制 LLVM 的链接时优化( link time optimizations )。通过对整个程序进行分析,并以增加链接时间为代价,LTO 可以生成更加优化的代码。
支持的选项包括:
false: 只会对代码生成单元中的本地包进行"thin" LTO优化,若代码生成单元数为 1 或者opt-level为 0,则不会进行任何 LTO 优化true或"fat":对依赖图中的所有包进行"fat" LTO优化"thin":对依赖图的所有包进行"thin" LTO,相比"fat"来说,它仅牺牲了一点性能,但是换来了链接时间的可观减少off: 禁用 LTO
如果大家想了解跨语言 LTO,可以看下 -C linker-plugin-lto 标志。
panic
panic 控制 -C panic 标志,它可以控制 panic 策略的选择。
支持的选项包括:
"unwind": 遇到 panic 后对栈进行展开( unwind )"abort": 遇到 panic 后直接停止程序
当设置为 "unwind" 时,具体的栈展开信息取决于特定的平台,例如 NVPTX 不支持 unwind,因此程序只能 "abort"。
测试、基准性能测试、构建脚本和过程宏会忽略 panic 设置,目前来说它们要求是 "unwind",如果大家希望修改成 "abort",可以看看 panic-abort-tests 。
另外,当你使用 "abort" 策略且在执行测试时,由于上述的要求,除了测试代码外,所有的依赖库也会忽略该 "abort" 设置而使用 "unwind" 策略。
incremental
incremental 控制 -C incremental 标志,用于开启或关闭增量编译。开启增量编译时,rustc 会将必要的信息存放到硬盘中( target 目录中 ),当下次编译时,这些信息可以被复用以改善编译时间。
支持的选项包括:
true: 启用false: 关闭
增量编译只能用于工作空间的成员和通过 path 引入的本地依赖。
大家还可以通过环境变量 CARGO_INCREMENTAL 或 Cargo 配置 build.incremental 在全局对 incremental 进行覆盖。
codegen-units
codegen-units 控制 -C codegen-units 标志,可以指定一个包会被分隔为多少个代码生成单元。更多的代码生成单元会提升代码的并行编译速度,但是可能会降低运行速度。
对于增量编译,默认值是 256,非增量编译是 16。
r-path
用于控制 -C rpath标志,可以控制 rpath 的启用与关闭。
rpath 代表硬编码到二进制可执行文件或库文件中的运行时代码搜索(runtime search path),动态链接库的加载器就通过它来搜索所需的库。
默认 profile
dev
dev profile 往往用于开发和 debug,cargo build 或 cargo run 默认使用的就是 dev profile,cargo build --debug 也是。
注意:
devprofile 的结果并没有输出到target/dev同名目录下,而是target/debug,这是历史遗留问题
默认的 dev profile 设置如下:
[profile.dev]
opt-level = 0
debug = true
split-debuginfo = '...' # Platform-specific.
debug-assertions = true
overflow-checks = true
lto = false
panic = 'unwind'
incremental = true
codegen-units = 256
rpath = false
release
release 往往用于预发/生产环境或性能测试,以下命令使用的就是 release profile:
cargo build --releasecargo run --releasecargo install
默认的 release profile 设置如下:
[profile.release]
opt-level = 3
debug = false
split-debuginfo = '...' # Platform-specific.
debug-assertions = false
overflow-checks = false
lto = false
panic = 'unwind'
incremental = false
codegen-units = 16
rpath = false
test
该 profile 用于构建测试,它的设置是继承自 dev
bench
bench profile 用于构建基准测试 benchmark,它的设计默认继承自 release
构建本身依赖
默认情况下,所有的 profile 都不会对构建过程本身所需的依赖进行优化,构建过程本身包括构建脚本、过程宏。
默认的设置是:
[profile.dev.build-override]
opt-level = 0
codegen-units = 256
[profile.release.build-override]
opt-level = 0
codegen-units = 256
如果是自定义 profile,那它会自动从当前正在使用的 profile 继承相应的设置,但不会修改。
重写 profile
我们还可以对特定的包使用的 profile 进行重写(override):
# `foo` package 将使用 -Copt-level=3 标志.
[profile.dev.package.foo]
opt-level = 3
这里的 package 名称实际上是一个 Package ID,因此我们还可以通过版本号来选择: [profile.dev.package."foo:2.1.0"]。
如果要为所有依赖包重写(不包括工作空间的成员):
[profile.dev.package."*"]
opt-level = 2
为构建脚本、过程宏和它们的依赖重写:
[profile.dev.build-override]
opt-level = 3
注意:如果一个依赖同时被正常代码和构建脚本所使用,当
--target没有指定时,Cargo 只会构建该依赖一次。但是当使用了
build-override后,该依赖会被构建两次,一次为正常代码,一次为构建脚本,因此会增加一些编译时间
重写的优先级按以下顺序执行(第一个匹配获胜):
[profile.dev.package.name],指定名称进行重写[profile.dev.package."*"],对所有非工作空间成员的 package 进行重写[profile.dev.build-override],对构建脚本、过程宏及它们的依赖进行重写[profile.dev]- Cargo 内置的默认值
重写无法使用 panic、lto 或 rpath 设置。
通过 config.toml 对 Cargo 进行配置
Cargo 相关的配置有两种,第一种是对自身进行配置,第二种是对指定的项目进行配置,关于后者请查看 Cargo.toml 清单。对于普通用户而言第二种才是我们最常使用的。
本文讲述的是如何对 Cargo 相关的工具进行配置,该配置中的部分内容可能会覆盖掉 Cargo.toml 中对应的部分,例如关于 profile 的内容。
层级结构
在前面我们已经见识过如何为 Cargo 进行全局配置:$HOME/.cargo/config.toml,事实上,还支持在一个 package 内对它进行配置。
总体原则是:Cargo 会顺着当前目录往上查找,直到找到目标配置文件。例如我们在目录 /projects/foo/bar/baz 下调用 Cargo 命令,那查找路径如下所示:
/projects/foo/bar/baz/.cargo/config.toml/projects/foo/bar/.cargo/config.toml/projects/foo/.cargo/config.toml/projects/.cargo/config.toml/.cargo/config.toml$CARGO_HOME/config.toml默认是 :- Windows:
%USERPROFILE%\.cargo\config.toml - Unix:
$HOME/.cargo/config.toml
- Windows:
有了这种机制,我们既可以在全局中设置默认的配置,又可以每个包都设定独立的配置,甚至还能做版本控制。
如果一个 key 在多个配置中出现,那这些 key 只会保留一个:最靠近 Cargo 执行目录的配置文件中的 key 的值将被最终使用(因此, HOME 下的都是最低优先级)。需要注意的是,如果 key 的值是数组,那相应的值将被合并( join )。
对于工作空间而言,Cargo 的搜索策略是从 root 开始,对于内部成员中包含的 .cargo.toml 会自动忽略。例如一个工作空间拥有两个成员,每个成员都有配置文件: /projects/foo/bar/baz/mylib/.cargo/config.toml 和 /projects/foo/bar/baz/mybin/.cargo/config.toml,但是 Cargo 并不会读取它们而是从工作空间的根( /projects/foo/bar/baz/ )开始往上查找。
注意:Cargo 还支持没有
.toml后缀的.cargo/config文件。对于.toml的支持是从 Rust 1.39 版本开始,同时也是目前最推荐的方式。但若同时存在有后缀和无后缀的文件,Cargo 将使用无后缀的!
配置文件概览
下面是一个完整的配置文件,并对常用的选项进行了翻译,大家可以参考下:
paths = ["/path/to/override"] # 覆盖 `Cargo.toml` 中通过 path 引入的本地依赖
[alias] # 命令别名
b = "build"
c = "check"
t = "test"
r = "run"
rr = "run --release"
space_example = ["run", "--release", "--", "\"command list\""]
[build]
jobs = 1 # 并行构建任务的数量,默认等于 CPU 的核心数
rustc = "rustc" # rust 编译器
rustc-wrapper = "…" # 使用该 wrapper 来替代 rustc
rustc-workspace-wrapper = "…" # 为工作空间的成员使用 该 wrapper 来替代 rustc
rustdoc = "rustdoc" # 文档生成工具
target = "triple" # 为 target triple 构建 ( `cargo install` 会忽略该选项)
target-dir = "target" # 存放编译输出结果的目录
rustflags = ["…", "…"] # 自定义flags,会传递给所有的编译器命令调用
rustdocflags = ["…", "…"] # 自定义flags,传递给 rustdoc
incremental = true # 是否开启增量编译
dep-info-basedir = "…" # path for the base directory for targets in depfiles
pipelining = true # rustc pipelining
[doc]
browser = "chromium" # `cargo doc --open` 使用的浏览器,
# 可以通过 `BROWSER` 环境变量进行重写
[env]
# Set ENV_VAR_NAME=value for any process run by Cargo
ENV_VAR_NAME = "value"
# Set even if already present in environment
ENV_VAR_NAME_2 = { value = "value", force = true }
# Value is relative to .cargo directory containing `config.toml`, make absolute
ENV_VAR_NAME_3 = { value = "relative/path", relative = true }
[cargo-new]
vcs = "none" # 所使用的 VCS ('git', 'hg', 'pijul', 'fossil', 'none')
[http]
debug = false # HTTP debugging
proxy = "host:port" # HTTP 代理,libcurl 格式
ssl-version = "tlsv1.3" # TLS version to use
ssl-version.max = "tlsv1.3" # 最高支持的 TLS 版本
ssl-version.min = "tlsv1.1" # 最小支持的 TLS 版本
timeout = 30 # HTTP 请求的超时时间,秒
low-speed-limit = 10 # 网络超时阈值 (bytes/sec)
cainfo = "cert.pem" # path to Certificate Authority (CA) bundle
check-revoke = true # check for SSL certificate revocation
multiplexing = true # HTTP/2 multiplexing
user-agent = "…" # the user-agent header
[install]
root = "/some/path" # `cargo install` 安装到的目标目录
[net]
retry = 2 # 网络重试次数
git-fetch-with-cli = true # 是否使用 `git` 命令来执行 git 操作
offline = true # 不能访问网络
[patch.<registry>]
# Same keys as for [patch] in Cargo.toml
[profile.<name>] # profile 配置,详情见"如何在 Cargo.toml 中配置 profile" : https://course.rs/cargo/reference/profiles.html#profile设置
opt-level = 0
debug = true
split-debuginfo = '...'
debug-assertions = true
overflow-checks = true
lto = false
panic = 'unwind'
incremental = true
codegen-units = 16
rpath = false
[profile.<name>.build-override]
[profile.<name>.package.<name>]
[registries.<name>] # 设置其它的注册服务: https://course.rs/cargo/reference/specify-deps.html#从其它注册服务引入依赖包
index = "…" # 注册服务索引列表的 URL
token = "…" # 连接注册服务所需的鉴权 token
[registry]
default = "…" # 默认的注册服务名称: crates.io
token = "…"
[source.<name>] # 注册服务源和替换source definition and replacement
replace-with = "…" # 使用给定的 source 来替换当前的 source,例如使用科大源来替换crates.io源以提升国内的下载速度:[source.crates-io] replace-with = 'ustc'
directory = "…" # path to a directory source
registry = "…" # 注册源的 URL ,例如科大源: [source.ustc] registry = "git://mirrors.ustc.edu.cn/crates.io-index"
local-registry = "…" # path to a local registry source
git = "…" # URL of a git repository source
branch = "…" # branch name for the git repository
tag = "…" # tag name for the git repository
rev = "…" # revision for the git repository
[target.<triple>]
linker = "…" # linker to use
runner = "…" # wrapper to run executables
rustflags = ["…", "…"] # custom flags for `rustc`
[target.<cfg>]
runner = "…" # wrapper to run executables
rustflags = ["…", "…"] # custom flags for `rustc`
[target.<triple>.<links>] # `links` build script override
rustc-link-lib = ["foo"]
rustc-link-search = ["/path/to/foo"]
rustc-flags = ["-L", "/some/path"]
rustc-cfg = ['key="value"']
rustc-env = {key = "value"}
rustc-cdylib-link-arg = ["…"]
metadata_key1 = "value"
metadata_key2 = "value"
[term]
verbose = false # whether cargo provides verbose output
color = 'auto' # whether cargo colorizes output
progress.when = 'auto' # whether cargo shows progress bar
progress.width = 80 # width of progress bar
环境变量
除了 config.toml 配置文件,我们还可以使用环境变量的方式对 Cargo 进行配置。
配置文件的中的 key foo.bar 对应的环境变量形式为 CARGO_FOO_BAR,其中的.、- 被转换成 _,且字母都变成大写的。例如,target.x86_64-unknown-linux-gnu.runner key 转换成环境变量后变成 CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER。
就优先级而言,环境变量是比配置文件更高的。除了上面的机制,Cargo 还支持一些预定义的环境变量。
官方 Cargo Book 中本文的内容还有很多,但是剩余内容对于绝大多数用户都用不到,因此我们并没有涵盖其中。
发布到 crates.io
如果你想要把自己的开源项目分享给全世界,那最好的办法自然是 GitHub。但如果是 Rust 的库,那除了发布到 GitHub 外,我们还可以将其发布到 crates.io 上,然后其它用户就可以很简单的对其进行引用。
注意:发布包到
crates.io后,特定的版本无法被覆盖,要发布就必须使用新的版本号,代码也无法被删除!
首次发布之前
首先,我们需要一个账号:访问 crates.io 的主页,然后在右上角使用 GitHub 账户登陆,接着访问你的账户设置页面,进入到 API Tokens 标签页下,生成新的 Token,并使用该 Token 在终端中进行登录:
$ cargo login abcdefghijklmnopqrstuvwxyz012345
该命令将告诉 Cargo 你的 API Token,然后将其存储在本地的 ~/.cargo/credentials.toml 文件中。
注意:你需要妥善保管好 API Token,并且不要告诉任何人,一旦泄漏,请撤销( Revoke )并重新生成。
发布包之前
crates.io 上的包名遵循先到先得的方式:一旦你想要的包名已经被使用,那么你就得换一个不同的包名。
在发布之前,确保 Cargo.toml 中以下字段已经被设置:
你还可以设置关键字和类别等元信息,让包更容易被其他人搜索发现,虽然它们不是必须的。
如果你发布的是一个依赖库,那么你可能需要遵循相关的命名规范和 API Guidlines.
打包
下一步就是将你的项目进行打包,然后上传到 crates.io。为了实现这个目的,我们可以使用 cargo publish 命令,该命令执行了以下步骤:
- 对项目进行一些验证
- 将源代码压缩到
.crate文件中 - 将
.crate文件解压并放入到临时的目录中,并验证解压出的代码可以顺利编译 - 上传
.crate文件到crates.io - 注册服务会对上传的包进行一些额外的验证,然后才会添加它到注册服务列表中
在发布之前,我们推荐你先运行 cargo publish --dry-run (或 cargo package ) 命令来确保代码没有 warning 或错误。
$ cargo publish --dry-run
你可以在 target/package 目录下观察生成的 .crate 文件。例如,目前 crates.io 要求该文件的大小不能超过 10MB,你可以通过手动检查该文件的大小来确保不会无意间打包进一些较大的资源文件,比如测试数据、网站文档或生成的代码等。我们还可以使用以下命令来检查其中包含的文件:
$ cargo package --list
当打包时,Cargo 会自动根据版本控制系统的配置来忽略指定的文件,例如 .gitignore。除此之外,你还可以通过 exclude 来排除指定的文件:
[package]
# ...
exclude = [
"public/assets/*",
"videos/*",
]
如果想要显式地将某些文件包含其中,可以使用 include,但是需要注意的是,这个 key 一旦设置,那 exclude 就将失效:
[package]
# ...
include = [
"**/*.rs",
"Cargo.toml",
]
上传包
准备好后,我们就可以正式来上传指定的包了,在根目录中运行:
$ cargo publish
就是这么简单,恭喜你,完成了第一个包的发布!
发布已上传包的新版本
绝大多数时候,我们并不是在发布新包,而是发布已经上传过的包的新版本。
为了实现这一点,只需修改 Cargo.toml 中的 version 字段 ,但需要注意:版本号需要遵循 semver 规则。
然后再次使用 cargo publish 就可以上传新的版本了。
管理 crates.io 上的包
目前来说,管理包更多地是通过 cargo 命令而不是在线管理,下面是一些你可以使用的命令。
cargo yank
有的时候你会遇到发布的包版本实际上并不可用(例如语法错误,或者忘记包含一个文件等),对于这种情况,Cargo 提供了 yank 命令:
$ cargo yank --vers 1.0.1
$ cargo yank --vers 1.0.1 --undo
该命令并不能删除任何代码,例如如果你上传了一段隐私内容,你需要的是立刻重置它们,而不是使用 cargo yank。
yank 能做到的就是让其它人不能再使用这个版本作为依赖,但是现存的依赖依然可以继续工作。crates.io 的一个主要目标就是作为一个不会随着时间变化的永久性包存档,但删除某个版本显然违背了这个目标。
cargo owner
一个包可能会有多个主要开发者,甚至维护者 maintainer 都会发生变更。目前来说,只有包的 owner 才能发布新的版本,但是一个 owner 可以指定其它的用户为 owner:
$ cargo owner --add github-handle
$ cargo owner --remove github-handle
$ cargo owner --add github:rust-lang:owners
$ cargo owner --remove github:rust-lang:owners
命令中使用的 ownerID 必须是 GitHub 用户名或 Team 名。
一旦一个用户 B 通过 --add 被加入到 owner 列表中,他将拥有该包相关的所有权利。例如发布新版本、yank 一个版本,还能增加和移除 owner,包含添加 B 为 owner 的 A 都可以被移除!
因此,我们必须严肃的指出:不要将你不信任的人添加为 owner ! 免得哪天反目成仇后,他把你移除了 - , -
但是对于 Team 又有所不同,通过 -add 添加的 GitHub Team owner,只拥有受限的权利。它们可以发布或 yank 某个版本,但是他们不能添加或移除 owner!总之,Team 除了可以很方便的管理所有者分组的同时,还能防止一些未知的恶意。
如果大家在添加 team 时遇到问题,可以看看官方的相关文档,由于绝大多数人都无需此功能,因此这里不再详细展开。
构建脚本( Build Scripts)
一些项目希望编译第三方的非 Rust 代码,例如 C 依赖库;一些希望链接本地或者基于源码构建的 C 依赖库;还有一些项目需要功能性的工具,例如在构建之间执行一些代码生成的工作等。
对于这些目标,社区已经提供了一些工具来很好的解决,Cargo 并不想替代它们,但是为了给用户带来一些便利,Cargo 提供了自定义构建脚本的方式,来帮助用户更好的解决类似的问题。
build.rs
若要创建构建脚本,我们只需在项目的根目录下添加一个 build.rs 文件即可。这样一来, Cargo 就会先编译和执行该构建脚本,然后再去构建整个项目。
以下是一个非常简单的脚本示例:
fn main() { // 以下代码告诉 Cargo ,一旦指定的文件 `src/hello.c` 发生了改变,就重新运行当前的构建脚本 println!("cargo:rerun-if-changed=src/hello.c"); // 使用 `cc` 来构建一个 C 文件,然后进行静态链接 cc::Build::new() .file("src/hello.c") .compile("hello"); }
关于构建脚本的一些使用场景如下:
- 构建 C 依赖库
- 在操作系统中寻找指定的 C 依赖库
- 根据某个说明描述文件生成一个 Rust 模块
- 执行一些平台相关的配置
下面的部分我们一起来看看构建脚本具体是如何工作的,然后在下个章节中还提供了一些关于如何编写构建脚本的示例。
Note:
package.build可以用于改变构建脚本的名称,或者直接禁用该功能
构建脚本的生命周期
在项目被构建之前,Cargo 会将构建脚本编译成一个可执行文件,然后运行该文件并执行相应的任务。
在运行的过程中,脚本可以使用之前 println 的方式跟 Cargo 进行通信:通信内容是以 cargo: 开头的格式化字符串。
需要注意的是,Cargo 也不是每次都会重新编译构建脚本,只有当脚本的内容或依赖发生变化时才会。默认情况下,任何文件变化都会触发重新编译,如果你希望对其进行定制,可以使用 rerun-if命令,后文会讲。
在构建脚本成功执行后,我们的项目就会开始进行编译。如果构建脚本的运行过程中发生错误,脚本应该通过返回一个非 0 码来立刻退出,在这种情况下,构建脚本的输出会被打印到终端中。
构建脚本的输入
我们可以通过环境变量的方式给构建脚本提供一些输入值,除此之外,构建脚本所在的当前目录也可以。
构建脚本的输出
构建脚本如果会产出文件,那么这些文件需要放在统一的目录中,该目录可以通过 OUT_DIR 环境变量来指定,构建脚本不应该修改该目录之外的任何文件!
在之前提到过,构建脚本可以通过 println! 输出内容跟 Cargo 进行通信:Cargo 会将每一行带有 cargo: 前缀的输出解析为一条指令,其它的输出内容会自动被忽略。
通过 println! 输出的内容在构建过程中默认是隐藏的,如果大家想要在终端中看到这些内容,你可以使用 -vv 来调用,以下 build.rs :
fn main() { println!("hello, build.rs"); }
将输出:
$ cargo run -vv
[study_cargo 0.1.0] hello, build.rs
构建脚本打印到标准输出 stdout 的所有内容将保存在文件 target/debug/build/<pkg>/output 中 (具体的位置可能取决于你的配置),stderr 的输出内容也将保存在同一个目录中。
以下是 Cargo 能识别的通信指令以及简介,如果大家希望深入了解每个命令,可以点击具体的链接查看官方文档的说明。
cargo:rerun-if-changed=PATH— 当指定路径的文件发生变化时,Cargo 会重新运行脚本cargo:rerun-if-env-changed=VAR— 当指定的环境变量发生变化时,Cargo 会重新运行脚本cargo:rustc-link-arg=FLAG– 将自定义的 flags 传给 linker,用于后续的基准性能测试 benchmark、 可执行文件 binary,、cdylib包、示例和测试cargo:rustc-link-arg-bin=BIN=FLAG– 自定义的 flags 传给 linker,用于可执行文件BINcargo:rustc-link-arg-bins=FLAG– 自定义的 flags 传给 linker,用于可执行文件cargo:rustc-link-arg-tests=FLAG– 自定义的 flags 传给 linker,用于测试cargo:rustc-link-arg-examples=FLAG– 自定义的 flags 传给 linker,用于示例cargo:rustc-link-arg-benches=FLAG– 自定义的 flags 传给 linker,用于基准性能测试 benchmarkcargo:rustc-cdylib-link-arg=FLAG— 自定义的 flags 传给 linker,用于cdylib包cargo:rustc-link-lib=[KIND=]NAME— 告知 Cargo 通过-l去链接一个指定的库,往往用于链接一个本地库,通过 FFIcargo:rustc-link-search=[KIND=]PATH— 告知 Cargo 通过-L将一个目录添加到依赖库的搜索路径中cargo:rustc-flags=FLAGS— 将特定的 flags 传给编译器cargo:rustc-cfg=KEY[="VALUE"]— 开启编译时cfg设置cargo:rustc-env=VAR=VALUE— 设置一个环境变量cargo:warning=MESSAGE— 在终端打印一条 warning 信息cargo:KEY=VALUE—links脚本使用的元数据
构建脚本的依赖
构建脚本也可以引入其它基于 Cargo 的依赖包,只需要在 Cargo.toml 中添加或修改以下内容:
[build-dependencies]
cc = "1.0.46"
需要这么配置的原因在于构建脚本无法使用通过 [dependencies] 或 [dev-dependencies] 引入的依赖包,因为构建脚本的编译运行过程跟项目本身的编译过程是分离的的,且前者先于后者发生。同样的,我们项目也无法使用 [build-dependencies] 中的依赖包。
大家在引入依赖的时候,需要仔细考虑它会给编译时间、开源协议和维护性等方面带来什么样的影响。如果你在 [build-dependencies] 和 [dependencies] 引入了同样的包,这种情况下 Cargo 也许会对依赖进行复用,也许不会,例如在交叉编译时,如果不会,那编译速度自然会受到不小的影响。
links
在 Cargo.toml 中可以配置 package.links 选项,它的目的是告诉 Cargo 当前项目所链接的本地库,同时提供了一种方式可以在项目构建脚本之间传递元信息。
[package]
# ...
links = "foo"
以上配置表明项目链接到一个 libfoo 本地库,当使用 links 时,项目必须拥有一个构建脚本,并且该脚本需要使用 rustc-link-lib 指令来链接目标库。
Cargo 要求一个本地库最多只能被一个项目所链接,换而言之,你无法让两个项目链接到同一个本地库,但是有一种方法可以降低这种限制,感兴趣的同学可以看看官方文档。
假设 A 项目的构建脚本生成任意数量的 kv 形式的元数据,那这些元数据将传递给 A 用作依赖包的项目的构建脚本。例如,如果包 bar 依赖于 foo,当 foo 生成 key=value 形式的构建脚本元数据时,那么 bar 的构建脚本就可以通过环境变量的形式使用该元数据:DEP_FOO_KEY=value。
需要注意的是,该元数据只能传给直接相关者,对于间接的,例如依赖的依赖,就无能为力了。
覆盖构建脚本
当 Cargo.toml 设置了 links 时, Cargo 就允许我们使用自定义库对现有的构建脚本进行覆盖。在 Cargo 使用的配置文件中添加以下内容:
[target.x86_64-unknown-linux-gnu.foo]
rustc-link-lib = ["foo"]
rustc-link-search = ["/path/to/foo"]
rustc-flags = "-L /some/path"
rustc-cfg = ['key="value"']
rustc-env = {key = "value"}
rustc-cdylib-link-arg = ["…"]
metadata_key1 = "value"
metadata_key2 = "value"
增加这个配置后,在未来,一旦我们的某个项目声明了它链接到 foo ,那项目的构建脚本将不会被编译和运行,替代的是这里的配置将被使用。
warning, rerun-if-changed 和 rerun-if-env-changed 这三个 key 在这里不应该被使用,就算用了也会被忽略。
构建脚本示例
下面我们通过一些例子来说明构建脚本该如何使用。社区中也提供了一些构建脚本的常用功能,例如:
- bindgen, 自动生成 Rust -> C 的 FFI 绑定
- cc, 编译 C/C++/汇编
- pkg-config, 使用
pkg-config工具检测系统库 - cmake, 运行
cmake来构建一个本地库 - autocfg, rustc_version, version_check,这些包提供基于
rustc的当前版本来实现条件编译的方法
代码生成
一些项目需要在编译开始前先生成一些代码,下面我们来看看如何在构建脚本中生成一个库调用。
先来看看项目的目录结构:
.
├── Cargo.toml
├── build.rs
└── src
└── main.rs
1 directory, 3 files
Cargo.toml 内容如下:
# Cargo.toml
[package]
name = "hello-from-generated-code"
version = "0.1.0"
接下来,再来看看构建脚本的内容:
// build.rs use std::env; use std::fs; use std::path::Path; fn main() { let out_dir = env::var_os("OUT_DIR").unwrap(); let dest_path = Path::new(&out_dir).join("hello.rs"); fs::write( &dest_path, "pub fn message() -> &'static str { \"Hello, World!\" } " ).unwrap(); println!("cargo:rerun-if-changed=build.rs"); }
以上代码中有几点值得注意:
OUT_DIR环境变量说明了构建脚本的输出目录,也就是最终生成的代码文件的存放地址- 一般来说,构建脚本不应该修改
OUT_DIR之外的任何文件 - 这里的代码很简单,但是我们这是为了演示,大家完全可以生成更复杂、更实用的代码
return-if-changed指令告诉 Cargo 只有在脚本内容发生变化时,才能重新编译和运行构建脚本。如果没有这一行,项目的任何文件发生变化都会导致 Cargo 重新编译运行该构建脚本
下面,我们来看看 main.rs:
// src/main.rs include!(concat!(env!("OUT_DIR"), "/hello.rs")); fn main() { println!("{}", message()); }
这里才是体现真正技术的地方,我们联合使用 rustc 定义的 include! 以及 concat! 和 env! 宏,将生成的代码文件( hello.rs ) 纳入到我们项目的编译流程中。
例子虽然很简单,但是它清晰地告诉了我们该如何生成代码文件以及将这些代码文件纳入到编译中来,大家以后有需要只要回头看看即可。
构建本地库
有时,我们需要在项目中使用基于 C 或 C++ 的本地库,而这种使用场景恰恰是构建脚本非常擅长的。
例如,下面来看看该如何在 Rust 中调用 C 并打印 Hello, World。首先,来看看项目结构和 Cargo.toml:
.
├── Cargo.toml
├── build.rs
└── src
├── hello.c
└── main.rs
1 directory, 4 files
# Cargo.toml
[package]
name = "hello-world-from-c"
version = "0.1.0"
edition = "2021"
现在,我们还不会使用任何构建依赖,先来看看构建脚本:
// build.rs use std::process::Command; use std::env; use std::path::Path; fn main() { let out_dir = env::var("OUT_DIR").unwrap(); Command::new("gcc").args(&["src/hello.c", "-c", "-fPIC", "-o"]) .arg(&format!("{}/hello.o", out_dir)) .status().unwrap(); Command::new("ar").args(&["crus", "libhello.a", "hello.o"]) .current_dir(&Path::new(&out_dir)) .status().unwrap(); println!("cargo:rustc-link-search=native={}", out_dir); println!("cargo:rustc-link-lib=static=hello"); println!("cargo:rerun-if-changed=src/hello.c"); }
首先,构建脚本将我们的 C 文件通过 gcc 编译成目标文件,然后使用 ar 将该文件转换成一个静态库,最后告诉 Cargo 我们的输出内容在 out_dir 中,编译器要在这里搜索相应的静态库,最终通过 -l static-hello 标志将我们的项目跟 libhello.a 进行静态链接。
但是这种硬编码的解决方式有几个问题:
gcc命令的跨平台性是受限的,例如 Windows 下就难以使用它,甚至于有些 Unix 系统也没有gcc命令,同样,ar也有这个问题- 这些命令往往不会考虑交叉编译的问题,如果我们要为 Android 平台进行交叉编译,那么
gcc很可能无法输出一个 ARM 的可执行文件
但是别怕,构建依赖 [build-dependencies] 解君忧:社区中已经有现成的解决方案,可以让这种任务得到更容易的解决。例如文章开头提到的 cc 包。首先在 Cargo.toml 中为构建脚本引入 cc 依赖:
[build-dependencies]
cc = "1.0"
然后重写构建脚本使用 cc :
// build.rs fn main() { cc::Build::new() .file("src/hello.c") .compile("hello"); println!("cargo:rerun-if-changed=src/hello.c"); }
不得不说,Rust 社区的大腿就是粗,代码立刻简洁了很多,最重要的是:可移植性、稳定性等头疼的问题也得到了一并解决。
简单来说,cc 包将构建脚本使用 C 的需求进行了抽象:
cc会针对不同的平台调用合适的编译器:windows 下调用 MSVC, MinGW 下调用 gcc, Unix 平台调用 cc 等- 在编译时会考虑到平台因素,例如将合适的标志传给正在使用的编译器
- 其它环境变量,例如
OPT_LEVEL、DEBUG等会自动帮我们处理 - 标准输出和
OUT_DIR的位置也会被cc所处理
如上所示,与其在每个构建脚本中复制粘贴相同的代码,将尽可能多的功能通过构建依赖来完成是好得多的选择。
再回到例子中,我们来看看 src 下的项目文件:
// src/hello.c
#include <stdio.h>
void hello() {
printf("Hello, World!\n");
}
// src/main.rs // 注意,这里没有再使用 `#[link]` 属性。我们把选择使用哪个 link 的责任交给了构建脚本,而不是在这里进行硬编码 extern { fn hello(); } fn main() { unsafe { hello(); } }
至此,这个简单的例子已经完成,我们学到了该如何使用构建脚本来构建 C 代码,当然又一次被构建脚本和构建依赖的强大所震撼!但控制下情绪,因为构建脚本还能做到更多。
链接系统库
当一个 Rust 包想要链接一个本地系统库时,如何实现平台透明化,就成了一个难题。
例如,我们想使用在 Unix 系统中的 zlib 库,用于数据压缩的目的。实际上,社区中的 libz-sys 包已经这么做了,但是出于演示的目的,我们来看看该如何手动完成,当然,这里只是简化版的,想要看完整代码,见这里。
为了更简单的定位到目标库的位置,可以使用 pkg-config 包,该包使用系统提供的 pkg-config 工具来查询库的信息。它会自动告诉 Cargo 该如何链接到目标库。
先修改 Cargo.toml:
# Cargo.toml
[package]
name = "libz-sys"
version = "0.1.0"
edition = "2021"
links = "z"
[build-dependencies]
pkg-config = "0.3.16"
这里的 links = "z" 用于告诉 Cargo 我们想要链接到 libz 库,在下文还有更多的示例。
构建脚本也很简单:
// build.rs fn main() { pkg_config::Config::new().probe("zlib").unwrap(); println!("cargo:rerun-if-changed=build.rs"); }
下面再在代码中使用:
#![allow(unused)] fn main() { // src/lib.rs use std::os::raw::{c_uint, c_ulong}; extern "C" { pub fn crc32(crc: c_ulong, buf: *const u8, len: c_uint) -> c_ulong; } #[test] fn test_crc32() { let s = "hello"; unsafe { assert_eq!(crc32(0, s.as_ptr(), s.len() as c_uint), 0x3610a686); } } }
代码很清晰,也很简洁,这里就不再过多介绍,运行 cargo build --vv 来看看部分结果( 系统中需要已经安装 libz 库):
[libz-sys 0.1.0] cargo:rustc-link-search=native=/usr/lib
[libz-sys 0.1.0] cargo:rustc-link-lib=z
[libz-sys 0.1.0] cargo:rerun-if-changed=build.rs
非常棒,pkg-config 帮助我们找到了目标库,并且还告知了 Cargo 所有需要的信息!
实际使用中,我们需要做的比上面的代码更多,例如 libz-sys 包会先检查环境变量 LIBZ_SYS_STATIC 或者 static feature,然后基于源码去构建 libz,而不是直接去使用系统库。
使用其它 sys 包
本例中,一起来看看该如何使用 libz-sys 包中的 zlib 来创建一个 C 依赖库。
若你有一个依赖于 zlib 的库,那可以使用 libz-sys 来自动发现或构建该库。这个功能对于交叉编译非常有用,例如 Windows 下往往不会安装 zlib。
libz-sys 通过设置 include 元数据来告知其它包去哪里找到 zlib 的头文件,然后我们的构建脚本可以通过 DEP_Z_INCLUDE 环境变量来读取 include 元数据( 关于元数据的传递,见这里 )。
# Cargo.toml
[package]
name = "zuser"
version = "0.1.0"
edition = "2021"
[dependencies]
libz-sys = "1.0.25"
[build-dependencies]
cc = "1.0.46"
通过包含 libz-sys,确保了最终只会使用一个 libz 库,并且给了我们在构建脚本中使用的途径:
// build.rs fn main() { let mut cfg = cc::Build::new(); cfg.file("src/zuser.c"); if let Some(include) = std::env::var_os("DEP_Z_INCLUDE") { cfg.include(include); } cfg.compile("zuser"); println!("cargo:rerun-if-changed=src/zuser.c"); }
由于 libz-sys 帮我们完成了繁重的相关任务,C 代码只需要包含 zlib 的头文件即可,甚至于它还能在没有安装 zlib 的系统上找到头文件:
// src/zuser.c
#include "zlib.h"
// … 在剩余的代码中使用 zlib
条件编译
构建脚本可以通过发出 rustc-cfg 指令来开启编译时的条件检查。在本例中,一起来看看 openssl 包是如何支持多版本的 OpenSSL 库的。
openssl-sys 包对 OpenSSL 库进行了构建和链接,支持多个不同的实现(例如 LibreSSL )和多个不同的版本。它也使用了 links 配置,这样就可以给其它构建脚本传递所需的信息。例如 version_number ,包含了检测到的 OpenSSL 库的版本号信息。openssl-sys 自己的构建脚本中有类似于如下的代码:
#![allow(unused)] fn main() { println!("cargo:version_number={:x}", openssl_version); }
该指令将 version_number 的信息通过环境变量 DEP_OPENSSL_VERSION_NUMBER 的方式传递给直接使用 openssl-sys 的项目。例如 openssl 包提供了更高级的抽象接口,并且它使用了 openssl-sys 作为依赖。openssl 的构建脚本会通过环境变量读取 openssl-sys 提供的版本号的信息,然后使用该版本号来生成一些 cfg:
#![allow(unused)] fn main() { // (portion of build.rs) if let Ok(version) = env::var("DEP_OPENSSL_VERSION_NUMBER") { let version = u64::from_str_radix(&version, 16).unwrap(); if version >= 0x1_00_01_00_0 { println!("cargo:rustc-cfg=ossl101"); } if version >= 0x1_00_02_00_0 { println!("cargo:rustc-cfg=ossl102"); } if version >= 0x1_01_00_00_0 { println!("cargo:rustc-cfg=ossl110"); } if version >= 0x1_01_00_07_0 { println!("cargo:rustc-cfg=ossl110g"); } if version >= 0x1_01_01_00_0 { println!("cargo:rustc-cfg=ossl111"); } } }
这些 cfg 可以跟 cfg 属性 或 cfg 宏一起使用以实现条件编译。例如,在 OpenSSL 1.1 中引入了 SHA3 的支持,那么我们就可以指定只有当版本号为 1.1 时,才包含并编译相关的代码:
#![allow(unused)] fn main() { // (portion of openssl crate) #[cfg(ossl111)] pub fn sha3_224() -> MessageDigest { unsafe { MessageDigest(ffi::EVP_sha3_224()) } } }
当然,大家在使用时一定要小心,因为这可能会导致生成的二进制文件进一步依赖当前的构建环境。例如,当二进制可执行文件需要在另一个操作系统中分发运行时,那它依赖的信息对于该操作系统可能是不存在的!
日志和监控
这几年 AIOps 特别火,但是你要是逮着一个运维问一下,他估计很难说出个所以然来,毕竟概念和现实往往是脱节的,前者的发展速度肯定远快于后者。
好在我大概了解这块儿领域,可以说智能化运维的核心就在于日志和监控,换而言之?何为智能,不就是基于已有的海量数据分析后进行决策吗?当然,你要说以前的知识库类型的运维决策也是智能,我也没办法杠: D
总之,不仅仅是对于开发者,对于整个技术链条的参与者,甚至包括老板,日志和监控都是开发实践中最最重要的一环。
详解日志
相比起监控,日志好理解的多:在某个时间点向指定的地方输出一条信息,里面记录着重要性、时间、地点和发生的事件,这就是日志。
注意,本文和 Rust 无关,我们争取从一个中立的角度去介绍何为日志
日志级别和输出位置
日志级别
日志级别是对基本的“滚动文本”式日志记录的一个重要补充。每条日志消息都会基于其重要性或严重程度分配到一个日志级别。例如,对于某个程序,“你的电脑着火了”是一个非常重要的消息,而“无法找到配置文件”的重要等级可能就低一些;但对于另外一些程序,"无法找到配置文件" 可能才是最严重的错误,会直接导致程序无法正常启动,而“电脑着火”? 我们可能会记录为一条 Debug 日志(参见下文) :D。
至于到底该如何定义日志级别,这是仁者见仁的事情,并没有一个约定俗成的方式,就连很多大公司,都无法保证自己的开发者严格按照它所制定的规则来输出日志。而下面是我认为的日志级别以及相关定义:
-
Fatal: 程序发生致命错误,祝你好运。这种错误往往来自于程序逻辑的严重异常,例如之前提到的“无法找到配置文件”,再比如无法分配足够的硬盘空间、内存不够用等。遇到这种错误,建议立即退出或者重启程序,然后记录下相应的错误信息
-
Error: 错误,一般指的是程序级别的错误或者严重的业务错误,但这种错误并不会影响程序的运行。一般的用户错误,例如用户名、密码错误等,不使用 Error 级别
-
Warn: 警告,说明这条记录信息需要注意,但是不确定是否发生了错误,因此需要相关的开发来辨别下。或者这条信息既不是错误,但是级别又没有低到 info 级别,就可以用 Warn 来给出警示。例如某条用户连接异常关闭、无法找到相关的配置只能使用默认配置、XX秒后重试等
-
Info: 信息,这种类型的日志往往用于记录程序的运行信息,例如用户操作或者状态的变化,再比如之前的用户名、密码错误,用户请求的开始和结束都可以记录为这个级别
-
Debug: 调试信息,顾名思义是给开发者用的,用于了解程序当前的详细运行状况, 例如用户请求详细信息跟踪、读取到的配置信息、连接握手发包(连接的建立和结束往往是 Info 级别),就可以记录为 Debug 信息
可以看出,日志级别很多,特别是 Debug 日志,如果在生产环境中开启,简直就是一场灾难,每秒几百上千条都很正常。因此我们需要控制日志的最低级别:将最低级别设置为 Info 时,意味着低于 Info 的日志都不会输出,对于上面的分级来说,Debug 日志将不会被输出。
有些开发为了让特定的日志在控制上显示更明显,还会为不同的级别使用不同颜色的文字。
输出位置
通常来说,日志可以输出两个地方:终端控制台和文件。对于前者,我们还有一个称呼标准输出,例如使用 println! 打印到终端的信息就是输出到标准输出中。
如果没有日志持久化的需求,你只是为了调试程序,建议输出到控制台即可。悄悄的说一句,我们还可以为不同的级别设定不同的输出位置,例如 Debug 日志输出到控制台,既方便开发查看,但又不会占用硬盘,而 Info 和 Warning 日志可以输出到文件 info.log 中,至于 Error、Fatal 则可以输出到 error.log 中。
但是如果大家以为只有输出到文件才能持久化日志,那你就错了,在后面的日志采集我们会详细介绍,先来看看日志查看。
日志查看
关于如何查看日志,相信大家都非常熟悉了,常用的方式有三种(事实上,可能也只有这三种):
- 在控制台查看,即可以直接查看输出到标准输出的日志,还可以使用 tail、cat、grep 等命令从日志文件中搜索查询或者以实时滚动的方式查看最新的日志
- 最简单的,进入到日志文件中,进行字符串搜索,或者从头到尾、从尾到头进行逐行查看
- 在可视化界面上查看,但是这个往往要配合日志采集工具,将日志采集到 ElasticSearch 或者其它搜索平台、数据中,然后再通过 kibana、grafana 等图形化服务进行搜索、查看,最重要的是可以进行日志的聚合统计,例如可以很方便的在 kibana 中查询满足指定条件的日志在某段时间内出现了多少次。
大家现在知道了,可视化,首先需要将日志集中采集起来,那么该如何采集日志呢?
日志采集
之前我们提到,不是只有输出到文件才能持久化日志,事实上,输出到控制台也能持久化日志。
其中的秘诀就在于使用一个日志采集工具去从控制台的标准输出读取日志数据,然后将读取到的数据发送到日志存储平台,例如 ElasticSearch,进行集中存储。当然,在存储前,还需要进行日志格式、数据的处理,以便只保留我们需要的格式和日志数据。
最典型的就是容器或容器云环境的日志采集,基本都是通过上面的方式进行的:容器中的进程将日志输出到标准输出,然后一个单独的日志采集服务直接读取标准输出中的日志,再通过网络发送到日志处理、存储的平台。大家发现了吗?这个流程完全不会在应用运行的本地或宿主机上存储任何日志,所以特别适合容器环境!
目前常用的日志采集工具有 filebeat、vector( Rust 开发,功能强大,性能非常高 ) 等,它们都是以 agent 的形式运行在你的应用程序旁边( 在同一个 pod 或虚拟机上 ),提供贴心的服务。
中心化日志存储
最后,我们再来简单介绍下日志存储。提到存储,首先不得不提的就是日志使用方式。
其实,除了 Debug 的时候,我们使用日志基本都是基于某个关键字进行搜索的,将日志存储在各台主机上的硬盘文件中,然后逐个去查询显然是非常非常低效的,最好的方式就是将日志集中收集上来后,存储在一个搜索平台中,例如 ElasticSearch。
当然,存储的时候肯定也不是简单的一行一行存储,而是需要将一条日志的多个关键词切取出来,然后以关键词索引的方式进行存储( 简化模型 ),这样我们就可以在后续使用时,通过关键词来搜索日志了。
日志门面 log
就如同 slf4j 是 Java 的日志门面库,log 也是 Rust 的日志门面库( 这不是我自己编的,官方用语: logging facade ),它目前由官方积极维护,因此大家可以放心使用。
使用方式很简单,只要在 Cargo.toml 中引入即可:
[dependencies]
log = "0.4"
日志门面不是说排场很大的意思,而是指相应的日志 API 已成为事实上的标准,会被其它日志框架所使用。通过这种统一的门面,开发者就可以不必再拘泥于日志框架的选择,未来大不了再换一个日志框架就是
既然是门面,log 自然定义了一套统一的日志特征和 API,将日志的操作进行了抽象。
Log 特征
例如,它定义了一个 Log 特征:
#![allow(unused)] fn main() { pub trait Log: Sync + Send { fn enabled(&self, metadata: &Metadata<'_>) -> bool; fn log(&self, record: &Record<'_>); fn flush(&self); } }
enabled用于判断某条带有元数据的日志是否能被记录,它对于log_enabled!宏特别有用log会记录record所代表的日志flush会将缓存中的日志数据刷到输出中,例如标准输出或者文件中
日志宏
log 还为我们提供了一整套标准的宏,用于方便地记录日志。看到 trace!、debug!、info!、warn!、error!,大家是否感觉眼熟呢?是的,它们跟上一章节提到的日志级别几乎一模一样,唯一的区别就是这里乱入了一个 trace!,它比 debug! 的日志级别还要低,记录的信息还要详细。可以说,你如果想巨细无遗地了解某个流程的所有踪迹,它就是不二之选。
#![allow(unused)] fn main() { use log::{info, trace, warn}; pub fn shave_the_yak(yak: &mut Yak) { trace!("Commencing yak shaving"); loop { match find_a_razor() { Ok(razor) => { info!("Razor located: {}", razor); yak.shave(razor); break; } Err(err) => { warn!("Unable to locate a razor: {}, retrying", err); } } } } }
上面的例子使用 trace! 记录了一条可有可无的信息:准备开始剃须,然后开始寻找剃须刀,找到后就用 info! 记录一条可能事后也没人看的信息:找到剃须刀;没找到的话,就记录一条 warn! 信息,这条信息就有一定价值了,不仅告诉我们没找到的原因,还记录了发生的次数,有助于事后定位问题。
可以看出,这里使用日志级别的方式和我们上一章节所述基本相符。
除了以上常用的,log 还提供了 log! 和 log_enabled! 宏,后者用于确定一条消息在当前模块中,对于给定的日志级别是否能够被记录
#![allow(unused)] fn main() { use log::Level::Debug; use log::{debug, log_enabled}; // 判断能否记录 Debug 消息 if log_enabled!(Debug) { let data = expensive_call(); // 下面的日志记录较为昂贵,因此我们先在前面判断了是否能够记录,能,才继续这里的逻辑 debug!("expensive debug data: {} {}", data.x, data.y); } if log_enabled!(target: "Global", Debug) { let data = expensive_call(); debug!(target: "Global", "expensive debug data: {} {}", data.x, data.y); } }
而 log! 宏就简单的多,它是一个通用的日志记录方式,因此需要我们手动指定日志级别:
#![allow(unused)] fn main() { use log::{log, Level}; let data = (42, "Forty-two"); let private_data = "private"; log!(Level::Error, "Received errors: {}, {}", data.0, data.1); log!(target: "app_events", Level::Warn, "App warning: {}, {}, {}", data.0, data.1, private_data); }
日志输出在哪里?
我不知道有没有同学尝试运行过上面的代码,但是我知道,就算你们运行了,也看不到任何输出。
为什么?原因很简单,log 仅仅是日志门面库,它并不具备完整的日志库功能!,因此你无法在控制台中看到任何日志输出,这种情况下,说实话,远不如一个 println! 有用!
但是别急,让我们看看该如何让 log 有用起来。
使用具体的日志库
log 包这么设计,其实是有很多好处的。
Rust 库的开发者
最直接的好处就是,如果你是一个 Rust 库开发者,那你自己或库的用户肯定都不希望这个库绑定任何具体的日志库,否则用户想使用 log1 来记录日志,你的库却使用了 log2,这就存在很多问题了!
因此,作为库的开发者,你只要在库中使用门面库即可,将具体的日志库交给用户去选择和绑定。
#![allow(unused)] fn main() { use log::{info, trace, warn}; pub fn deal_with_something() { // 开始处理 // 记录一些日志 trace!("a trace log"); info!("a info long: {}", "abc"); warn!("a warning log: {}, retrying", err); // 结束处理 } }
应用开发者
如果是应用开发者,那你的应用运行起来,却看不到任何日志输出,这种场景想想都捉急。此时就需要去选择一个具体的日志库了。
目前来说,已经有了不少日志库实现,官方也推荐了一些 ,大家可以根据自己的需求来选择,不过 env_logger 是一个相当不错的选择。
log 还提供了 set_logger 函数用于设置日志库,set_max_level 用于设置最大日志级别,但是如果你选了具体的日志库,它往往会提供更高级的 API,无需我们手动调用这两个函数,例如下面的 env_logger 就是如此。
env_logger
修改 Cargo.toml , 添加以下内容:
# in Cargo.toml
[dependencies]
log = "0.4.0"
env_logger = "0.9"
在 src/main.rs 中添加如下代码:
use log::{debug, error, log_enabled, info, Level}; fn main() { // 注意,env_logger 必须尽可能早的初始化 env_logger::init(); debug!("this is a debug {}", "message"); error!("this is printed by default"); if log_enabled!(Level::Info) { let x = 3 * 4; // expensive computation info!("the answer was: {}", x); } }
在运行程序时,可以通过环境变量来设定日志级别:
$ RUST_LOG=error ./main
[2017-11-09T02:12:24Z ERROR main] this is printed by default
我们还可以为单独一个模块指定日志级别:
$ RUST_LOG=main=info ./main
[2017-11-09T02:12:24Z ERROR main] this is printed by default
[2017-11-09T02:12:24Z INFO main] the answer was: 12
还能为某个模块开启所有日志级别:
$ RUST_LOG=main ./main
[2017-11-09T02:12:24Z DEBUG main] this is a debug message
[2017-11-09T02:12:24Z ERROR main] this is printed by default
[2017-11-09T02:12:24Z INFO main] the answer was: 12
需要注意的是,如果文件名包含 -,你需要将其替换成下划线来使用,原因是 Rust 的模块和包名不支持使用 -。
$ RUST_LOG=my_app ./my-app
[2017-11-09T02:12:24Z DEBUG my_app] this is a debug message
[2017-11-09T02:12:24Z ERROR my_app] this is printed by default
[2017-11-09T02:12:24Z INFO my_app] the answer was: 12
默认情况下,env_logger 会输出到标准错误 stderr,如果你想要输出到标准输出 stdout,可以使用 Builder 来改变日志对象( target ):
#![allow(unused)] fn main() { use std::env; use env_logger::{Builder, Target}; let mut builder = Builder::from_default_env(); builder.target(Target::Stdout); builder.init(); }
默认
#![allow(unused)] fn main() { if cfg!(debug_assertions) { eprintln!("debug: {:?} -> {:?}", record, fields); } }
日志库开发者
对于这类开发者而言,自然要实现自己的 Log 特征咯:
#![allow(unused)] fn main() { use log::{Record, Level, Metadata}; struct SimpleLogger; impl log::Log for SimpleLogger { fn enabled(&self, metadata: &Metadata) -> bool { metadata.level() <= Level::Info } fn log(&self, record: &Record) { if self.enabled(record.metadata()) { println!("{} - {}", record.level(), record.args()); } } fn flush(&self) {} } }
除此之外,我们还需要像 env_logger 一样包装下 set_logger 和 set_max_level:
#![allow(unused)] fn main() { use log::{SetLoggerError, LevelFilter}; static LOGGER: SimpleLogger = SimpleLogger; pub fn init() -> Result<(), SetLoggerError> { log::set_logger(&LOGGER) .map(|()| log::set_max_level(LevelFilter::Info)) } }
更多示例
关于 log 门面库和具体的日志库还有更多的使用方式,详情请参见锈书的开发者工具一章。
使用 tracing 记录日志
严格来说,tracing 并不是一个日志库,而是一个分布式跟踪的 SDK,用于采集监控数据的。
随着微服务的流行,现在一个产品有多个系统组成是非常常见的,这种情况下,一条用户请求可能会横跨几个甚至几十个服务。此时再用传统的日志方式去跟踪这条用户请求就变得较为困难,这就是分布式追踪在现代化监控系统中这么炽手可热的原因。
关于分布式追踪,在后面的监控章节进行详细介绍,大家只要知道:分布式追踪的核心就是在请求的开始生成一个 trace_id,然后将该 trace_id 一直往后透穿,请求经过的每个服务都会使用该 trace_id 记录相关信息,最终将整个请求形成一个完整的链路予以记录下来。
那么后面当要查询这次请求的相关信息时,只要使用 trace_id 就可以获取整个请求链路的所有信息了,非常简单好用。看到这里,相信大家也明白为什么这个库的名称叫 tracing 了吧?
至于为何把它归到日志库的范畴呢?因为 tracing 支持 log 门面库的 API,因此,它既可以作为分布式追踪的 SDK 来使用,也可以作为日志库来使用。
在分布式追踪中,trace_id 都是由 SDK 自动生成和往后透穿,对于用户的使用来说是完全透明的。如果你要手动用日志的方式来实现请求链路的追踪,那么就必须考虑 trace_id 的手动生成、透传,以及不同语言之间的协议规范等问题
一个简单例子
开始之前,需要先将 tracing 添加到项目的 Cargo.toml 中:
[dependencies]
tracing = "0.1"
注意,在写作本文时,0.2 版本已经快要出来了,所以具体使用的版本请大家以阅读时为准。
下面的例子中将同时使用 log 和 tracing :
use log; use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt}; fn main() { // 只有注册 subscriber 后, 才能在控制台上看到日志输出 tracing_subscriber::registry() .with(fmt::layer()) .init(); // 调用 `log` 包的 `info!` log::info!("Hello world"); let foo = 42; // 调用 `tracing` 包的 `info!` tracing::info!(foo, "Hello from tracing"); }
可以看出,门面库的排场还是有的,tracing 在 API 上明显是使用了 log 的规范。
运行后,输出如下日志:
2022-04-09T14:34:28.965952Z INFO test_tracing: Hello world
2022-04-09T14:34:28.966011Z INFO test_tracing: Hello from tracing foo=42
还可以看出,log 的日志格式跟 tracing 一模一样,结合上一章节的知识,相信聪明的同学已经明白了这是为什么。
那么 tracing 跟 log 的具体日志实现框架有何区别呢?别急,我们再来接着看。
异步编程中的挑战
除了分布式追踪,在异步编程中使用传统的日志也是存在一些问题的,最大的挑战就在于异步任务的执行没有确定的顺序,那么输出的日志也将没有确定的顺序并混在一起,无法按照我们想要的逻辑顺序串联起来。
归根到底,在于日志只能针对某个时间点进行记录,缺乏上下文信息,而线程间的执行顺序又是不确定的,因此日志就有些无能为力。而 tracing 为了解决这个问题,引入了 span 的概念( 这个概念也来自于分布式追踪 ),一个 span 代表了一个时间段,拥有开始和结束时间,在此期间的所有类型数据、结构化数据、文本数据都可以记录其中。
大家发现了吗? span 是可以拥有上下文信息的,这样就能帮我们把信息按照所需的逻辑性串联起来了。
核心概念
tracing 中最重要的三个概念是 Span、Event 和 Collector,下面我们来一一简单介绍下。
Span
相比起日志只能记录在某个时间点发生的事件,span 最大的意义就在于它可以记录一个过程,也就是在某一段时间内发生的事件流。既然是记录时间段,那自然有开始和结束:
use tracing::{span, Level}; fn main() { let span = span!(Level::TRACE, "my_span"); // `enter` 返回一个 RAII ,当其被 drop 时,将自动结束该 span let enter = span.enter(); // 这里开始进入 `my_span` 的上下文 // 下面执行一些任务,并记录一些信息到 `my_span` 中 // ... } // 这里 enter 将被 drop,`my_span` 也随之结束
Event 事件
Event 代表了某个时间点发生的事件,这方面它跟日志类似,但是不同的是,Event 还可以产生在 span 的上下文中。
use tracing::{event, span, Level}; use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt}; fn main() { tracing_subscriber::registry().with(fmt::layer()).init(); // 在 span 的上下文之外记录一次 event 事件 event!(Level::INFO, "something happened"); let span = span!(Level::INFO, "my_span"); let _guard = span.enter(); // 在 "my_span" 的上下文中记录一次 event event!(Level::DEBUG, "something happened inside my_span"); }
2022-04-09T14:51:38.382987Z INFO test_tracing: something happened
2022-04-09T14:51:38.383111Z DEBUG my_span: test_tracing: something happened inside my_span
虽然 event 在哪里都可以使用,但是最好只在 span 的上下文中使用:用于代表一个时间点发生的事件,例如记录 HTTP 请求返回的状态码,从队列中获取一个对象,等等。
Collector 收集器
当 Span 或 Event 发生时,它们会被实现了 Collect 特征的收集器所记录或聚合。这个过程是通过通知的方式实现的:当 Event 发生或者 Span 开始/结束时,会调用 Collect 特征的相应方法通知 Collector。
tracing-subscriber
我们前面提到只有使用了 tracing-subscriber 后,日志才能输出到控制台中。
之前大家可能还不理解,现在应该明白了,它是一个 Collector,可以将记录的日志收集后,再输出到控制台中。
使用方法
span! 宏
span! 宏可以用于创建一个 Span 结构体,然后通过调用结构体的 enter 方法来开始,再通过超出作用域时的 drop 来结束。
use tracing::{span, Level}; fn main() { let span = span!(Level::TRACE, "my_span"); // `enter` 返回一个 RAII ,当其被 drop 时,将自动结束该 span let enter = span.enter(); // 这里开始进入 `my_span` 的上下文 // 下面执行一些任务,并记录一些信息到 `my_span` 中 // ... } // 这里 enter 将被 drop,`my_span` 也随之结束
#[instrument]
如果想要将某个函数的整个函数体都设置为 span 的范围,最简单的方法就是为函数标记上 #[instrument],此时 tracing 会自动为函数创建一个 span,span 名跟函数名相同,在输出的信息中还会自动带上函数参数。
use tracing::{info, instrument}; use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt}; #[instrument] fn foo(ans: i32) { info!("in foo"); } fn main() { tracing_subscriber::registry().with(fmt::layer()).init(); foo(42); }
2022-04-10T02:44:12.885556Z INFO foo{ans=42}: test_tracing: in foo
关于 #[instrument] 详细说明,请参见官方文档。
in_scope
对于没有内置 tracing 支持或者无法使用 #instrument 的函数,例如外部库的函数,我们可以使用 Span 结构体的 in_scope 方法,它可以将同步代码包裹在一个 span 中:
#![allow(unused)] fn main() { use tracing::info_span; let json = info_span!("json.parse").in_scope(|| serde_json::from_slice(&buf))?; }
在 async 中使用 span
需要注意,如果是在异步编程时使用,要避免以下使用方式:
#![allow(unused)] fn main() { async fn my_async_function() { let span = info_span!("my_async_function"); // WARNING: 该 span 直到 drop 后才结束,因此在 .await 期间,span 依然处于工作中状态 let _enter = span.enter(); // 在这里 span 依然在记录,但是 .await 会让出当前任务的执行权,然后运行时会去运行其它任务,此时这个 span 可能会记录其它任务的执行信息,最终记录了不正确的 trace 信息 some_other_async_function().await // ... } }
我们建议使用以下方式,简单又有效:
#![allow(unused)] fn main() { use tracing::{info, instrument}; use tokio::{io::AsyncWriteExt, net::TcpStream}; use std::io; #[instrument] async fn write(stream: &mut TcpStream) -> io::Result<usize> { let result = stream.write(b"hello world\n").await; info!("wrote to stream; success={:?}", result.is_ok()); result } }
那有同学可能要问了,是不是我们无法在异步代码中使用 span.enter 了,答案是:是也不是。
是,你无法直接使用 span.enter 语法了,原因上面也说过,但是可以通过下面的方式来曲线使用:
#![allow(unused)] fn main() { use tracing::Instrument; let my_future = async { // ... }; my_future .instrument(tracing::info_span!("my_future")) .await }
span 嵌套
tracing 的 span 不仅仅是上面展示的基本用法,它们还可以进行嵌套!
use tracing::{debug, info, span, Level}; use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt}; fn main() { tracing_subscriber::registry().with(fmt::layer()).init(); let scope = span!(Level::DEBUG, "foo"); let _enter = scope.enter(); info!("Hello in foo scope"); debug!("before entering bar scope"); { let scope = span!(Level::DEBUG, "bar", ans = 42); let _enter = scope.enter(); debug!("enter bar scope"); info!("In bar scope"); debug!("end bar scope"); } debug!("end bar scope"); }
INFO foo: log_test: Hello in foo scope
DEBUG foo: log_test: before entering bar scope
DEBUG foo:bar{ans=42}: log_test: enter bar scope
INFO foo:bar{ans=42}: log_test: In bar scope
DEBUG foo:bar{ans=42}: log_test: end bar scope
DEBUG foo: log_test: end bar scope
在上面的日志中,foo:bar 不仅包含了 foo 和 bar span 名,还显示了它们之间的嵌套关系。
对宏进行配置
日志级别和目标
span! 和 event! 宏都需要设定相应的日志级别,而且它们支持可选的 target 或 parent 参数( 只能二者选其一 ),该参数用于描述事件发生的位置,如果父 span 没有设置,target 参数也没有提供,那这个位置默认分别是当前的 span 和 当前的模块。
use tracing::{debug, info, span, Level,event}; use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt}; fn main() { tracing_subscriber::registry().with(fmt::layer()).init(); let s = span!(Level::TRACE, "my span"); // 没进入 span,因此输出日志将不会带上 span 的信息 event!(target: "app_events", Level::INFO, "something has happened 1!"); // 进入 span ( 开始 ) let _enter = s.enter(); // 没有设置 target 和 parent // 这里的对象位置分别是当前的 span 名和模块名 event!(Level::INFO, "something has happened 2!"); // 设置了 target // 这里的对象位置分别是当前的 span 名和 target event!(target: "app_events",Level::INFO, "something has happened 3!"); let span = span!(Level::TRACE, "my span 1"); // 这里就更为复杂一些,留给大家作为思考题 event!(parent: &span, Level::INFO, "something has happened 4!"); }
记录字段
我们可以通过语法 field_name = field_value 来输出结构化的日志
#![allow(unused)] fn main() { // 记录一个事件,带有两个字段: // - "answer", 值是 42 // - "question", 值是 "life, the universe and everything" event!(Level::INFO, answer = 42, question = "life, the universe, and everything"); // 日志输出 -> INFO test_tracing: answer=42 question="life, the universe, and everything" }
捕获环境变量
还可以捕获环境中的变量:
#![allow(unused)] fn main() { let user = "ferris"; // 下面的简写方式 span!(Level::TRACE, "login", user); // 等价于: span!(Level::TRACE, "login", user = user); }
use tracing::{info, span, Level}; use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt}; fn main() { tracing_subscriber::registry().with(fmt::layer()).init(); let user = "ferris"; let s = span!(Level::TRACE, "login", user); let _enter = s.enter(); info!(welcome="hello", user); // 下面一行将报错,原因是这种写法是格式化字符串的方式,必须使用 info!("hello {}", user) // info!("hello", user); } // 日志输出 -> INFO login{user="ferris"}: test_tracing: welcome="hello" user="ferris"
字段名的多种形式
字段名还可以包含 . :
#![allow(unused)] fn main() { let user = "ferris"; let email = "ferris@rust-lang.org"; event!(Level::TRACE, user, user.email = email); // 还可以使用结构体 let user = User { name: "ferris", email: "ferris@rust-lang.org", }; // 直接访问结构体字段,无需赋值即可使用 span!(Level::TRACE, "login", user.name, user.email); // 字段名还可以使用字符串 event!(Level::TRACE, "guid:x-request-id" = "abcdef", "type" = "request"); // 日志输出 -> // TRACE test_tracing: user="ferris" user.email="ferris@rust-lang.org" // TRACE test_tracing: user.name="ferris" user.email="ferris@rust-lang.org" // TRACE test_tracing: guid:x-request-id="abcdef" type="request" }
?
? 符号用于说明该字段将使用 fmt::Debug 来格式化。
#![allow(unused)] fn main() { #[derive(Debug)] struct MyStruct { field: &'static str, } let my_struct = MyStruct { field: "Hello world!", }; // `my_struct` 将使用 Debug 的形式输出 event!(Level::TRACE, greeting = ?my_struct); // 等价于: event!(Level::TRACE, greeting = tracing::field::debug(&my_struct)); // 下面代码将报错, my_struct 没有实现 Display // event!(Level::TRACE, greeting = my_struct); // 日志输出 -> TRACE test_tracing: greeting=MyStruct { field: "Hello world!" } }
%
% 说明字段将用 fmt::Display 来格式化。
#![allow(unused)] fn main() { // `my_struct.field` 将使用 `fmt::Display` 的格式化形式输出 event!(Level::TRACE, greeting = %my_struct.field); // 等价于: event!(Level::TRACE, greeting = tracing::field::display(&my_struct.field)); // 作为对比,大家可以看下 Debug 和正常的字段输出长什么样 event!(Level::TRACE, greeting = ?my_struct.field); event!(Level::TRACE, greeting = my_struct.field); // 下面代码将报错, my_struct 没有实现 Display // event!(Level::TRACE, greeting = %my_struct); }
2022-04-10T03:49:00.834330Z TRACE test_tracing: greeting=Hello world!
2022-04-10T03:49:00.834410Z TRACE test_tracing: greeting=Hello world!
2022-04-10T03:49:00.834422Z TRACE test_tracing: greeting="Hello world!"
2022-04-10T03:49:00.834433Z TRACE test_tracing: greeting="Hello world!"
Empty
字段还能标记为 Empty,用于说明该字段目前没有任何值,但是可以在后面进行记录。
#![allow(unused)] fn main() { use tracing::{trace_span, field}; let span = trace_span!("my_span", greeting = "hello world", parting = field::Empty); // ... // 现在,为 parting 记录一个值 span.record("parting", &"goodbye world!"); }
格式化字符串
除了以字段的方式记录信息,我们还可以使用格式化字符串的方式( 同 println! 、format! )。
注意,当字段跟格式化的方式混用时,必须把格式化放在最后,如下所示
#![allow(unused)] fn main() { let question = "the ultimate question of life, the universe, and everything"; let answer = 42; event!( Level::DEBUG, question.answer = answer, question.tricky = true, "the answer to {} is {}.", question, answer ); // 日志输出 -> DEBUG test_tracing: the answer to the ultimate question of life, the universe, and everything is 42. question.answer=42 question.tricky=true }
文件输出
截至目前,我们上面的日志都是输出到控制台中。
针对文件输出,tracing 提供了一个专门的库 tracing-appender,大家可以查看官方文档了解更多。
一个综合例子
最后,再来看一个综合的例子,使用了 color-eyre 和 文件输出,前者用于为输出的日志加上更易读的颜色。
use color_eyre::{eyre::eyre, Result}; use tracing::{error, info, instrument}; use tracing_appender::{non_blocking, rolling}; use tracing_error::ErrorLayer; use tracing_subscriber::{ filter::EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt, Registry, }; #[instrument] fn return_err() -> Result<()> { Err(eyre!("Something went wrong")) } #[instrument] fn call_return_err() { info!("going to log error"); if let Err(err) = return_err() { // 推荐大家运行下,看看这里的输出效果 error!(?err, "error"); } } fn main() -> Result<()> { let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); // 输出到控制台中 let formatting_layer = fmt::layer().pretty().with_writer(std::io::stderr); // 输出到文件中 let file_appender = rolling::never("logs", "app.log"); let (non_blocking_appender, _guard) = non_blocking(file_appender); let file_layer = fmt::layer() .with_ansi(false) .with_writer(non_blocking_appender); // 注册 Registry::default() .with(env_filter) // ErrorLayer 可以让 color-eyre 获取到 span 的信息 .with(ErrorLayer::default()) .with(formatting_layer) .with(file_layer) .init(); // 安裝 color-eyre 的 panic 处理句柄 color_eyre::install()?; call_return_err(); Ok(()) }
总结 & 推荐
至此,tracing 的介绍就已结束,相信大家都看得出,它比上个章节的 log 及兄弟们要更加复杂一些,一方面是因为它能更好的支持异步编程环境,另一方面就是它还是一个分布式追踪的库,对于后者,我们将在后续的监控章节进行讲解。
如果你让我推荐使用哪个,那我的建议是:
- 对于简单的工程,例如用于 POC( Proof of Concepts ) 目的,使用
log即可 - 对于需要认真对待,例如生产级或优秀的开源项目,建议使用 tracing 的方式,一举解决日志和监控的后顾之忧
使用 tracing 输出自定义的 Rust 日志
在 tracing 包出来前,Rust 的日志也就 log 有一战之力,但是 log 的功能相对来说还是简单一些。在大名鼎鼎的 tokio 开发团队推出 tracing 后,我现在坚定的认为 tracing 就是未来!
截至目前,rust编译器团队、GraphQL 都在使用 tracing,而且 tokio 在密谋一件大事:基于 tracing 开发一套终端交互式 debug 工具: console!
基于这种坚定的信仰,我们决定将公司之前使用的 log 包替换成 tracing ,但是有一个问题:后者提供的 JSON logger 总感觉不是那个味儿。这意味着,对于程序员来说,最快乐的时光又要到来了:定制自己的开发工具。
好了,闲话少说,下面我们一起来看看该如何构建自己的 logger,以及深入了解 tracing 的一些原理,当然你也可以只选择来凑个热闹,总之,开始吧!
打地基(1)
首先,使用 cargo new --bin test-tracing 创建一个新的二进制类型( binary )的项目。
然后引入以下依赖:
# in cargo.toml
[dependencies]
serde_json = "1"
tracing = "0.1"
tracing-subscriber = "0.3"
其中 tracing-subscriber 用于订阅正在发生的日志、监控事件,然后可以对它们进行进一步的处理。serde_json 可以帮我们更好的处理格式化的 JSON,毕竟咱们要解决的问题就来自于 JSON logger。
下面来实现一个基本功能:设置自定义的 logger,并使用 info! 来打印一行日志。
// in examples/figure_0/main.rs use tracing::info; use tracing_subscriber::prelude::*; mod custom_layer; use custom_layer::CustomLayer; fn main() { // 设置 `tracing-subscriber` 对 tracing 数据的处理方式 tracing_subscriber::registry().with(CustomLayer).init(); // 打印一条简单的日志。用 `tracing` 的行话来说,`info!` 将创建一个事件 info!(a_bool = true, answer = 42, message = "first example"); }
大家会发现,上面引入了一个模块 custom_layer, 下面从该模块开始,来实现我们的自定义 logger。首先,tracing-subscriber 提供了一个特征 Layer 专门用于处理 tracing 的各种事件( span, event )。
#![allow(unused)] fn main() { // in examples/figure_0/custom_layer.rs use tracing_subscriber::Layer; pub struct CustomLayer; impl<S> Layer<S> for CustomLayer where S: tracing::Subscriber {} }
由于还没有填入任何代码,运行该示例比你打的水漂还无力 - 毫无效果。
捕获事件
在 tracing 中,当 info!、error! 等日志宏被调用时,就会产生一个相应的事件 Event。
而我们首先,就要为之前的 Layer 特征实现 on_event 方法。
// in examples/figure_0/custom_layer.rs where S: tracing::Subscriber, { fn on_event( &self, event: &tracing::Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>, ) { println!("Got event!"); println!(" level={:?}", event.metadata().level()); println!(" target={:?}", event.metadata().target()); println!(" name={:?}", event.metadata().name()); for field in event.fields() { println!(" field={}", field.name()); } } }
从代码中可以看出,我们打印了事件中包含的事件名、日志等级以及事件发生的代码路径。运行后,可以看到以下输出:
$ cargo run --example figure_1
Got event!
level=Level(Info)
target="figure_1"
name="event examples/figure_1/main.rs:10"
field=a_bool
field=answer
field=message
但是奇怪的是,我们无法通过 API 来获取到具体的 field 值。还有就是,上面的输出还不是 JSON 格式。
现在问题来了,要创建自己的 logger,不能获取 filed 显然是不靠谱的。
访问者模式
在设计上,tracing 作出了一个选择:永远不会自动存储产生的事件数据( spans, events )。如果我们要获取这些数据,就必须自己手动存储。
解决办法就是使用访问者模式(Visitor Pattern):手动实现 Visit 特征去获取事件中的值。Visit 为每个 tracing 可以处理的类型都提供了对应的 record_X 方法。
#![allow(unused)] fn main() { // in examples/figure_2/custom_layer.rs struct PrintlnVisitor; impl tracing::field::Visit for PrintlnVisitor { fn record_f64(&mut self, field: &tracing::field::Field, value: f64) { println!(" field={} value={}", field.name(), value) } fn record_i64(&mut self, field: &tracing::field::Field, value: i64) { println!(" field={} value={}", field.name(), value) } fn record_u64(&mut self, field: &tracing::field::Field, value: u64) { println!(" field={} value={}", field.name(), value) } fn record_bool(&mut self, field: &tracing::field::Field, value: bool) { println!(" field={} value={}", field.name(), value) } fn record_str(&mut self, field: &tracing::field::Field, value: &str) { println!(" field={} value={}", field.name(), value) } fn record_error( &mut self, field: &tracing::field::Field, value: &(dyn std::error::Error + 'static), ) { println!(" field={} value={}", field.name(), value) } fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) { println!(" field={} value={:?}", field.name(), value) } } }
然后在之前的 on_event 中来使用这个新的访问者: event.record(&mut visitor) 可以访问其中的所有值。
#![allow(unused)] fn main() { // in examples/figure_2/custom_layer.rs fn on_event( &self, event: &tracing::Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>, ) { println!("Got event!"); println!(" level={:?}", event.metadata().level()); println!(" target={:?}", event.metadata().target()); println!(" name={:?}", event.metadata().name()); let mut visitor = PrintlnVisitor; event.record(&mut visitor); } }
这段代码看起来有模有样,来运行下试试:
$ cargo run --example figure_2
Got event!
level=Level(Info)
target="figure_2"
name="event examples/figure_2/main.rs:10"
field=a_bool value=true
field=answer value=42
field=message value=first example
Bingo ! 一切完美运行 !
构建 JSON logger
目前为止,离我们想要的 JSON logger 只差一步了。下面来实现一个 JsonVisitor 替代之前的 PrintlnVisitor 用于构建一个 JSON 对象。
#![allow(unused)] fn main() { // in examples/figure_3/custom_layer.rs impl<'a> tracing::field::Visit for JsonVisitor<'a> { fn record_f64(&mut self, field: &tracing::field::Field, value: f64) { self.0 .insert(field.name().to_string(), serde_json::json!(value)); } fn record_i64(&mut self, field: &tracing::field::Field, value: i64) { self.0 .insert(field.name().to_string(), serde_json::json!(value)); } fn record_u64(&mut self, field: &tracing::field::Field, value: u64) { self.0 .insert(field.name().to_string(), serde_json::json!(value)); } fn record_bool(&mut self, field: &tracing::field::Field, value: bool) { self.0 .insert(field.name().to_string(), serde_json::json!(value)); } fn record_str(&mut self, field: &tracing::field::Field, value: &str) { self.0 .insert(field.name().to_string(), serde_json::json!(value)); } fn record_error( &mut self, field: &tracing::field::Field, value: &(dyn std::error::Error + 'static), ) { self.0.insert( field.name().to_string(), serde_json::json!(value.to_string()), ); } fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) { self.0.insert( field.name().to_string(), serde_json::json!(format!("{:?}", value)), ); } } }
#![allow(unused)] fn main() { // in examples/figure_3/custom_layer.rs fn on_event( &self, event: &tracing::Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>, ) { // Covert the values into a JSON object let mut fields = BTreeMap::new(); let mut visitor = JsonVisitor(&mut fields); event.record(&mut visitor); // Output the event in JSON let output = serde_json::json!({ "target": event.metadata().target(), "name": event.metadata().name(), "level": format!("{:?}", event.metadata().level()), "fields": fields, }); println!("{}", serde_json::to_string_pretty(&output).unwrap()); } }
继续运行:
$ cargo run --example figure_3
{
"fields": {
"a_bool": true,
"answer": 42,
"message": "first example"
},
"level": "Level(Info)",
"name": "event examples/figure_3/main.rs:10",
"target": "figure_3"
}
终于,我们实现了自己的 logger,并且成功地输出了一条 JSON 格式的日志。并且新实现的 Layer 就可以添加到 tracing-subscriber 中用于记录日志事件。
下面再来一起看看如何使用tracing 提供的 period-of-time spans 为日志增加更详细的上下文信息。
何为 span
在之前我们多次提到 span 这个词,但是何为 span?
不知道大家知道分布式追踪不?在分布式系统中每一个请求从开始到返回,会经过多个服务,这条请求路径被称为请求跟踪链路( trace ),可以看出,一条链路是由多个部分组成,我们可以简单的把其中一个部分认为是一个 span。
跟 log 是对某个时间点的记录不同,span 记录的是一个时间段。当程序开始执行一系列任务时,span 就会开始,当这一系列任务结束后,span 也随之结束。
由此可见,tracing 其实不仅仅是一个日志库,它还是一个分布式追踪的库,可以帮助我们采集信息,然后上传给 jaeger 等分布式追踪平台,最终实现对指定应用程序的监控。
在理解后,再来看看该如何为自定义的 logger 实现 spans。
打地基(2)
先来创建一个外部 span 和一个内部 span,从概念上来说,spans 和 events 创建的东东类似以下嵌套结构:
- 进入外部 span
- 进入内部 span
- 事件已创建,内部 span 是它的父 span,外部 span 是它的祖父 span
- 结束内部 span
- 进入内部 span
- 结束外部 span
有些同学可能还是不太理解,你就把 span 理解成为监控埋点,进入 span == 埋点开始,结束 span == 埋点结束
在下面的代码中,当使用 span.enter() 创建的 span 超出作用域时,将自动退出:根据 Drop 特征触发的顺序,inner_span 将先退出,然后才是 outer_span 的退出。
// in examples/figure_5/main.rs use tracing::{debug_span, info, info_span}; use tracing_subscriber::prelude::*; mod custom_layer; use custom_layer::CustomLayer; fn main() { tracing_subscriber::registry().with(CustomLayer).init(); let outer_span = info_span!("outer", level = 0); let _outer_entered = outer_span.enter(); let inner_span = debug_span!("inner", level = 1); let _inner_entered = inner_span.enter(); info!(a_bool = true, answer = 42, message = "first example"); }
再回到事件处理部分,通过使用 examples/figure_0/main.rs 我们能获取到事件的父 span,当然,前提是它存在。但是在实际场景中,直接使用 ctx.event_scope(event) 来迭代所有 span 会更加简单好用。
注意,这种迭代顺序类似于栈结构,以上面的代码为例,先被迭代的是 inner_span,然后才是 outer_span。
当然,如果你不想以类似于出栈的方式访问,还可以使用 scope.from_root() 直接反转,此时的访问将从最外层开始: outer -> innter。
对了,为了使用 ctx.event_scope(),我们的订阅者还需实现 LookupRef。提前给出免责声明:这里的实现方式有些诡异,大家可能难以理解,但是..我们其实也无需理解,只要这么用即可。
译者注:这里用到了高阶生命周期 HRTB( Higher Ranke Trait Bounds ) 的概念,一般的读者无需了解,感兴趣的可以看看(这里)[https://doc.rust-lang.org/nomicon/hrtb.html]
#![allow(unused)] fn main() { // in examples/figure_5/custom_layer.rs impl<S> Layer<S> for CustomLayer where S: tracing::Subscriber, // 好可怕! 还好我们不需要理解它,只要使用即可 S: for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>, { fn on_event(&self, event: &tracing::Event<'_>, ctx: tracing_subscriber::layer::Context<'_, S>) { // 父 span let parent_span = ctx.event_span(event).unwrap(); println!("parent span"); println!(" name={}", parent_span.name()); println!(" target={}", parent_span.metadata().target()); println!(); // 迭代范围内的所有的 spans let scope = ctx.event_scope(event).unwrap(); for span in scope.from_root() { println!("an ancestor span"); println!(" name={}", span.name()); println!(" target={}", span.metadata().target()); } } } }
运行下看看效果:
$ cargo run --example figure_5
parent span
name=inner
target=figure_5
an ancestor span
name=outer
target=figure_5
an ancestor span
name=inner
target=figure_5
细心的同学可能会发现,这里怎么也没有 field 数据?没错,而且恰恰是这些 field 包含的数据才让日志和监控有意义。那我们可以像之前一样,使用访问器 Visitor 来解决吗?
span 的数据在哪里
答案是:No。因为 ctx.event_scope 返回的东东没有任何办法可以访问其中的字段。
不知道大家还记得我们为何之前要使用访问器吗?很简单,因为 tracing 默认不会去存储数据,既然如此,那 span 这种跨了某个时间段的,就更不可能去存储数据了。
现在只能看看 Layer 特征有没有提供其它的方法了,哦呦,发现了一个 on_new_span,从名字可以看出,该方法是在 span 创建时调用的。
#![allow(unused)] fn main() { // in examples/figure_6/custom_layer.rs impl<S> Layer<S> for CustomLayer where S: tracing::Subscriber, S: for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>, { fn on_new_span( &self, attrs: &tracing::span::Attributes<'_>, id: &tracing::span::Id, ctx: tracing_subscriber::layer::Context<'_, S>, ) { let span = ctx.span(id).unwrap(); println!("Got on_new_span!"); println!(" level={:?}", span.metadata().level()); println!(" target={:?}", span.metadata().target()); println!(" name={:?}", span.metadata().name()); // Our old friend, `println!` exploration. let mut visitor = PrintlnVisitor; attrs.record(&mut visitor); } } }
$ cargo run --example figure_6
Got on_new_span!
level=Level(Info)
target="figure_7"
name="outer"
field=level value=0
Got on_new_span!
level=Level(Debug)
target="figure_7"
name="inner"
field=level value=1
芜湖! 我们的数据回来了!但是这里有一个隐患:只能在创建的时候去访问数据。如果仅仅是为了记录 spans,那没什么大问题,但是如果我们随后需要记录事件然后去尝试访问之前的 span 呢?此时 span 的数据已经不存在了!
如果 tracing 不能存储数据,那我们这些可怜的开发者该怎么办?
自己存储 span 数据
何为一个优秀的程序员?能偷懒的时候绝不多动半跟手指,但是需要勤快的时候,也是自己动手丰衣足食的典型。
因此,既然 tracing 不支持,那就自己实现吧。先确定一个目标:捕获 span 的数据,然后存储在某个地方以便后续访问。
好在 tracing-subscriber 提供了扩展 extensions 的方式,可以让我们轻松地存储自己的数据,该扩展甚至可以跟每一个 span 联系在一起!
虽然我们可以把之前见过的 BTreeMap<String, serde_json::Value> 存在扩展中,但是由于扩展数据是被 registry 中的所有layers 所共享的,因此出于私密性的考虑,还是只保存私有字段比较合适。这里使用一个 newtype 模式来创建新的类型:
#![allow(unused)] fn main() { // in examples/figure_8/custom_layer.rs #[derive(Debug)] struct CustomFieldStorage(BTreeMap<String, serde_json::Value>); }
每次发现一个新的 span 时,都基于它来构建一个 JSON 对象,然后将其存储在扩展数据中。
#![allow(unused)] fn main() { // in examples/figure_8/custom_layer.rs fn on_new_span( &self, attrs: &tracing::span::Attributes<'_>, id: &tracing::span::Id, ctx: tracing_subscriber::layer::Context<'_, S>, ) { // 基于 field 值来构建我们自己的 JSON 对象 let mut fields = BTreeMap::new(); let mut visitor = JsonVisitor(&mut fields); attrs.record(&mut visitor); // 使用之前创建的 newtype 包裹下 let storage = CustomFieldStorage(fields); // 获取内部 span 数据的引用 let span = ctx.span(id).unwrap(); // 获取扩展,用于存储我们的 span 数据 let mut extensions = span.extensions_mut(); // 存储! extensions.insert::<CustomFieldStorage>(storage); } }
这样,未来任何时候我们都可以取到该 span 包含的数据( 例如在 on_event 方法中 )。
#![allow(unused)] fn main() { // in examples/figure_8/custom_layer.rs fn on_event(&self, event: &tracing::Event<'_>, ctx: tracing_subscriber::layer::Context<'_, S>) { let scope = ctx.event_scope(event).unwrap(); println!("Got event!"); for span in scope.from_root() { let extensions = span.extensions(); let storage = extensions.get::<CustomFieldStorage>().unwrap(); println!(" span"); println!(" target={:?}", span.metadata().target()); println!(" name={:?}", span.metadata().name()); println!(" stored fields={:?}", storage); } } }
功能齐全的 JSON logger
截至目前,我们已经学了不少东西,下面来利用这些知识实现最后的 JSON logger。
#![allow(unused)] fn main() { // in examples/figure_9/custom_layer.rs fn on_event(&self, event: &tracing::Event<'_>, ctx: tracing_subscriber::layer::Context<'_, S>) { // All of the span context let scope = ctx.event_scope(event).unwrap(); let mut spans = vec![]; for span in scope.from_root() { let extensions = span.extensions(); let storage = extensions.get::<CustomFieldStorage>().unwrap(); let field_data: &BTreeMap<String, serde_json::Value> = &storage.0; spans.push(serde_json::json!({ "target": span.metadata().target(), "name": span.name(), "level": format!("{:?}", span.metadata().level()), "fields": field_data, })); } // The fields of the event let mut fields = BTreeMap::new(); let mut visitor = JsonVisitor(&mut fields); event.record(&mut visitor); // And create our output let output = serde_json::json!({ "target": event.metadata().target(), "name": event.metadata().name(), "level": format!("{:?}", event.metadata().level()), "fields": fields, "spans": spans, }); println!("{}", serde_json::to_string_pretty(&output).unwrap()); } }
$ cargo run --example figure_9
{
"fields": {
"a_bool": true,
"answer": 42,
"message": "first example"
},
"level": "Level(Info)",
"name": "event examples/figure_9/main.rs:16",
"spans": [
{
"fields": {
"level": 0
},
"level": "Level(Info)",
"name": "outer",
"target": "figure_9"
},
{
"fields": {
"level": 1
},
"level": "Level(Debug)",
"name": "inner",
"target": "figure_9"
}
],
"target": "figure_9"
}
嗯,完美。
等等,你说功能齐全?
上面的代码在发布到生产环境后,依然运行地相当不错,但是我发现还缺失了一个功能: span 在创建之后,依然要能记录数据。
#![allow(unused)] fn main() { // in examples/figure_10/main.rs let outer_span = info_span!("outer", level = 0, other_field = tracing::field::Empty); let _outer_entered = outer_span.enter(); // Some code... outer_span.record("other_field", &7); }
如果基于之前的代码运行上面的代码,我们将不会记录 other_field,因为该字段在收到 on_new_span 事件时,还不存在。
对此,Layer 提供了 on_record 方法:
#![allow(unused)] fn main() { // in examples/figure_10/custom_layer.rs fn on_record( &self, id: &tracing::span::Id, values: &tracing::span::Record<'_>, ctx: tracing_subscriber::layer::Context<'_, S>, ) { // 获取正在记录数据的 span let span = ctx.span(id).unwrap(); // 获取数据的可变引用,该数据是在 on_new_span 中创建的 let mut extensions_mut = span.extensions_mut(); let custom_field_storage: &mut CustomFieldStorage = extensions_mut.get_mut::<CustomFieldStorage>().unwrap(); let json_data: &mut BTreeMap<String, serde_json::Value> = &mut custom_field_storage.0; // 使用我们的访问器老朋友 let mut visitor = JsonVisitor(json_data); values.record(&mut visitor); } }
终于,在最后,我们拥有了一个功能齐全的自定义的 JSON logger,大家快去尝试下吧。当然,你也可以根据自己的需求来定制专属于你的 logger,毕竟方法是一通百通的。
在以下 github 仓库,可以找到完整的代码: https://github.com/bryanburgers/tracing-blog-post
本文由 Rustt 提供翻译 原文链接: https://github.com/studyrs/Rustt/blob/main/Articles/%5B2022-04-07%5D%20在%20Rust%20中使用%20tracing%20自定义日志.md
监控
监控是一个很大的领域,大到老板、前端开发、后端开发理解的监控可能都不相同。
- 老板眼中的监控:业务大数据实时展示
- 前端眼中的监控:手机 APP 收集上来的异常、崩溃、用户操作日志等
- 后端眼中的监控:请求链路跟踪、一段时间内的请求错误率、QPS 过高、异常日志等
正是因为这些复杂性,导致很多同学难以准确的说出监控到底是什么。
下面,我们将试图解释清楚监控的概念,并引入一个全新的概念:可观测性。
可观测性
在监控章节的引言中,我们提到了老板、前端、后端眼中的监控是各不相同的,那么有没有办法将监控模型进行抽象、统一呢?
来简单分析一下:
- 业务指标实时展示,这是一个指标型的数据( metric )
- 手机 APP 上传的数据,包含了日志( log )和指标类型( metric ),如果考虑到 APP 作为一次 HTTP 请求的发起端,那还涉及到请求链路的跟踪( trace)
- 后端链路跟踪是 trace,请求错误率、QPS 是 metric,异常日志是 log
喔,好像线索很明显哎,我们貌似可以把监控模型分为三种:指标 metric、日志 log 和 链路 trace。
先别急,我们对总结出来的三种类型进行下对比,看看彼此之间是否存在关联性( 良好的模型设计,彼此之间应该是无关联的 ):
- 指标:用于表示在某一段时间内,一个行为出现的次数和分布
- 日志:记录在某一个时间点发生的一次事件
- 链路:记录一次请求所经过的完整的服务链路,可能会横跨线程、进程,也可能会横跨服务( 分布式、微服务 )
按照这个定义来看,三种类型几乎没有关联性,是不是意味着我们的监控模型非常成功?
恭喜你,刚才总结出的监控模型正是这几年非常火热的可观测性监控的三大基础:Metrics / Log / Trace。
各自为战的三种模型
但是如果按照这个模型,我们将监控分成三个部分开发,彼此没有关联,并且在使用之时,也带着孤立的观点去看待这些数据和功能,那可观测性就失去了其应有的意义。
例如要看指标趋势变化就使用 metrics,查看详细问题使用 log,要看请求链路、链路各部分的耗时、服务依赖都使用 trace,虽然看起来很美好,但是它们都在各自为战。
例如一个很常见的场景,现在我们通过 metrics 获得了一个告警,发现某个服务的 SLA 降低、错误率上升,此时该如何排查错误原因? 查看日志?你如何确保日志跟错误率上升有内在的联系呢?而且一个大型服务,它的各种类型的日志、错误都是非常频繁的,要大海捞针般地找出特定的日志,非常难。
由于缺乏数据模型上的关联,最后只能各自为战:发现了错误率上升,就人工去找日志和链路,运气好,就能很快地查明原因,运气不好?等待老板和用户的咆哮吧
这个过程很不美好,需要工程师们充分理解每一项数据的底层逻辑,而在大型微服务架构中,没有一个工程师可以清晰的知道所有的底层逻辑,此时就需要分工协作去排查,那问题处理的复杂度和挑战性最终会急剧增加。
模型纽带
看来,要解决这个问题,我们需要一个纽带,来把三个模型串联起来,目前来看,trace 是最适合的。
因为问题的跟踪和解决其实就是沿着数据的流向来的,我们只要在 trace 流动的过程中,在沿途把相关的 log 收集上来,然后再针对收到的各种 trace,根据其标签去统计相应的指标。
这样,是不是就成功地将三个模型关联在了一起?而且还不是强扭的瓜!
再回到之前假设的场景:当我们对某个 Metric 波动发生兴趣时,可以直接将造成此波动的 Trace 关联检索出来,然后查看这些 Trace 在各个微服务中的所有执行细节,最后发现是底层某个微服务在执行请求过程中发生了 Panic,这个错误不断向上传播导致了服务对外 SLA 下降。
如果可观测平台做得更完善一些,将微服务的变更事件数据也呈现出来,那么一个工程师就可以快速完成整个排障和根因定位的过程,甚至不需要人,通过机器就可以自动完成整个排障和根因定位过程。
看到这里,相信大家都已经明白了 trace 的重要性以及可观测性监控到底优秀在哪里。那么问题来了,该如何落地?
数据采集
首先,没有数据,就没有一切,因此我们需要先把监控数据采集上来。
除了跨服务的数据统一规范外,由于现在的微服务往往使用多种语言实现,我们的数据采集还要支持不同的语言,选择一个合适的数据采集 SDK 就成了重中之重。
目前来说,我们最推荐大家采用 OpenTelemetry 作为可观测性解决方案,它提供了完整的数据协议规范、API和多语言采集 SDK,我们将在下个章节进行详细介绍。
数据处理和存储
虽然在我们之前的模型设计完善后,数据彼此之间存在内在关联性,但是不代表它们就能够按照同样的格式来存储了,甚至都无法保证使用同一个数据库来存储。
就目前而言,对于三种模型的数据处理和存储推荐如下:
- Trace,使用 jaeger 接收采集上来的 trace 数据,经过处理后存储到一个分布式数据库中,例如 cassandra、scyllaDB 等
- Log,如果对日志的关键词索引有较高的要求,还是建议使用 ElasticeSearch,如果可以提前在日志中通过 kv 的形式打上标签,然后未来也只需要通过标签来索引,那可以考虑使用 loki
- Metrics,啥都不用说了,prometheus 走起,当然还可以使用 influxdb,后者正在使用 Rust 重写,期待未来的一飞冲天
数据查询和展示
大家知道可观测性现在为什么很多人搞不清楚吗?就是因为你怎么做都可以,比如之前的存储,就有很多解决方案,而且还都不错。
对于数据展示也是,你可以使用上面的 jaeger、prometheus 自带的 UI,也可以使用 grafana 这种统一性的 UI,而从我个人来说,更推荐使用 grafana,毕竟 UI 的统一性和内联性对于监控数据的查询是非常重要的。
再说了,grafana 的 UI 做的好看啊,没人能拒绝美好的事物吧 :D
好了,一篇口水文终于结束了,在后续章节我们将学习如何使用 OpenTelemetry + Jaeger + Prometheus + Grafana 搭建一套可用的监控服务,先来看看如何搭建和使用分布式追踪监控。
"tracing 呢?你这个监控服务怎么没有它的身影,日志章节口口声声的爱,现在就忘记了吗?"
"别急,我还记得呢,先卖个关子"
分布式追踪
附录 A:关键字
下面的列表包含 Rust 中正在使用或者以后会用到的关键字。因此,这些关键字不能被用作标识符(除了原生标识符),包括函数、变量、参数、结构体字段、模块、包、常量、宏、静态值、属性、类型、特征或生命周期。
目前正在使用的关键字
如下关键字目前有对应其描述的功能。
as- 强制类型转换,或use和extern crate包和模块引入语句中的重命名break- 立刻退出循环const- 定义常量或原生常量指针(constant raw pointer)continue- 继续进入下一次循环迭代crate- 链接外部包dyn- 动态分发特征对象else- 作为if和if let控制流结构的 fallbackenum- 定义一个枚举类型extern- 链接一个外部包,或者一个宏变量(该变量定义在另外一个包中)false- 布尔值falsefn- 定义一个函数或 函数指针类型 (function pointer type)for- 遍历一个迭代器或实现一个 trait 或者指定一个更高级的生命周期if- 基于条件表达式的结果来执行相应的分支impl- 为结构体或者特征实现具体功能in-for循环语法的一部分let- 绑定一个变量loop- 无条件循环match- 模式匹配mod- 定义一个模块move- 使闭包获取其所捕获项的所有权mut- 在引用、裸指针或模式绑定中使用,表明变量是可变的pub- 表示结构体字段、impl块或模块的公共可见性ref- 通过引用绑定return- 从函数中返回Self- 实现特征类型的类型别名self- 表示方法本身或当前模块static- 表示全局变量或在整个程序执行期间保持其生命周期struct- 定义一个结构体super- 表示当前模块的父模块trait- 定义一个特征true- 布尔值truetype- 定义一个类型别名或关联类型unsafe- 表示不安全的代码、函数、特征或实现use- 在当前代码范围内(模块或者花括号对)引入外部的包、模块等where- 表示一个约束类型的从句while- 基于一个表达式的结果判断是否继续循环
保留做将来使用的关键字
如下关键字没有任何功能,不过由 Rust 保留以备将来的应用。
abstractasyncawaitbecomeboxdofinalmacrooverrideprivtrytypeofunsizedvirtualyield
原生标识符
原生标识符(Raw identifiers)允许你使用通常不能使用的关键字,其带有 r# 前缀。
例如,match 是关键字。如果尝试编译如下使用 match 作为名字的函数:
fn match(needle: &str, haystack: &str) -> bool {
haystack.contains(needle)
}
会得到这个错误:
error: expected identifier, found keyword `match`
--> src/main.rs:4:4
|
4 | fn match(needle: &str, haystack: &str) -> bool {
| ^^^^^ expected identifier, found keyword
该错误表示你不能将关键字 match 用作函数标识符。你可以使用原生标识符将 match 作为函数名称使用:
文件名: src/main.rs
fn r#match(needle: &str, haystack: &str) -> bool { haystack.contains(needle) } fn main() { assert!(r#match("foo", "foobar")); }
此代码编译没有任何错误。注意 r# 前缀需同时用于函数名定义和 main 函数中的调用。
原生标识符允许使用你选择的任何单词作为标识符,即使该单词恰好是保留关键字。 此外,原生标识符允许你使用其它 Rust 版本编写的库。比如,try 在 Rust 2015 edition 中不是关键字,却在 Rust 2018 edition 是关键字。所以如果用 2015 edition 编写的库中带有 try 函数,在 2018 edition 中调用时就需要使用原始标识符语法,在这里是 r#try。
附录 B:运算符与符号
该附录包含了 Rust 目前出现过的各种符号,这些符号之前都分散在各个章节中。
运算符
表 B-1 包含了 Rust 中的运算符、上下文中的示例、简短解释以及该运算符是否可重载。如果一个运算符是可重载的,则该运算符上用于重载的特征也会列出。
下表中,expr 是表达式,ident 是标识符,type 是类型,var 是变量,trait 是特征,pat 是匹配分支(pattern)。
表 B-1:运算符
| 运算符 | 示例 | 解释 | 是否可重载 |
|---|---|---|---|
! | ident!(...), ident!{...}, ident![...] | 宏展开 | |
! | !expr | 按位非或逻辑非 | Not |
!= | var != expr | 不等比较 | PartialEq |
% | expr % expr | 算术求余 | Rem |
%= | var %= expr | 算术求余与赋值 | RemAssign |
& | &expr, &mut expr | 借用 | |
& | &type, &mut type, &'a type, &'a mut type | 借用指针类型 | |
& | expr & expr | 按位与 | BitAnd |
&= | var &= expr | 按位与及赋值 | BitAndAssign |
&& | expr && expr | 逻辑与 | |
* | expr * expr | 算术乘法 | Mul |
*= | var *= expr | 算术乘法与赋值 | MulAssign |
* | *expr | 解引用 | |
* | *const type, *mut type | 裸指针 | |
+ | trait + trait, 'a + trait | 复合类型限制 | |
+ | expr + expr | 算术加法 | Add |
+= | var += expr | 算术加法与赋值 | AddAssign |
, | expr, expr | 参数以及元素分隔符 | |
- | - expr | 算术取负 | Neg |
- | expr - expr | 算术减法 | Sub |
-= | var -= expr | 算术减法与赋值 | SubAssign |
-> | fn(...) -> type, |...| -> type | 函数与闭包,返回类型 | |
. | expr.ident | 成员访问 | |
.. | .., expr.., ..expr, expr..expr | 右半开区间 | PartialOrd |
..= | ..=expr, expr..=expr | 闭合区间 | PartialOrd |
.. | ..expr | 结构体更新语法 | |
.. | variant(x, ..), struct_type { x, .. } | “代表剩余部分”的模式绑定 | |
... | expr...expr | (不推荐使用,用..=替代) 闭合区间 | |
/ | expr / expr | 算术除法 | Div |
/= | var /= expr | 算术除法与赋值 | DivAssign |
: | pat: type, ident: type | 约束 | |
: | ident: expr | 结构体字段初始化 | |
: | 'a: loop {...} | 循环标志 | |
; | expr; | 语句和语句结束符 | |
; | [...; len] | 固定大小数组语法的部分 | |
<< | expr << expr | 左移 | Shl |
<<= | var <<= expr | 左移与赋值 | ShlAssign |
< | expr < expr | 小于比较 | PartialOrd |
<= | expr <= expr | 小于等于比较 | PartialOrd |
= | var = expr, ident = type | 赋值/等值 | |
== | expr == expr | 等于比较 | PartialEq |
=> | pat => expr | 匹配分支语法的部分 | |
> | expr > expr | 大于比较 | PartialOrd |
>= | expr >= expr | 大于等于比较 | PartialOrd |
>> | expr >> expr | 右移 | Shr |
>>= | var >>= expr | 右移与赋值 | ShrAssign |
@ | ident @ pat | 模式绑定 | |
^ | expr ^ expr | 按位异或 | BitXor |
^= | var ^= expr | 按位异或与赋值 | BitXorAssign |
| | pat | pat | 模式匹配中的多个可选条件 | |
| | expr | expr | 按位或 | BitOr |
|= | var |= expr | 按位或与赋值 | BitOrAssign |
|| | expr || expr | 逻辑或 | |
? | expr? | 错误传播 |
非运算符符号
表 B-2:独立语法
| 符号 | 解释 |
|---|---|
'ident | 生命周期名称或循环标签 |
...u8, ...i32, ...f64, ...usize, 等 | 指定类型的数值常量 |
"..." | 字符串常量 |
r"...", r#"..."#, r##"..."##, etc. | 原生字符串, 未转义字符 |
b"..." | 将 &str 转换成 &[u8; N] 类型的数组 |
br"...", br#"..."#, br##"..."##, 等 | 原生字节字符串,原生和字节字符串字面值的结合 |
'...' | Char 字符 |
b'...' | ASCII 字节 |
|...| expr | 闭包 |
! | 代表总是空的类型,用于发散函数(无返回值函数) |
_ | 模式绑定中表示忽略的意思;也用于增强整型字面值的可读性 |
表 B-3 展示了模块和对象调用路径的语法。
表 B-3:路径相关语法
| 符号 | 解释 |
|---|---|
ident::ident | 命名空间路径 |
::path | 从当前的包的根路径开始的相对路径 |
self::path | 与当前模块相对的路径(如一个显式相对路径) |
super::path | 与父模块相对的路径 |
type::ident, <type as trait>::ident | 关联常量、关联函数、关联类型 |
<type>::... | 不可以被直接命名的关联项类型(如 <&T>::...,<[T]>::..., 等) |
trait::method(...) | 使用特征名进行方法调用,以消除方法调用的二义性 |
type::method(...) | 使用类型名进行方法调用, 以消除方法调用的二义性 |
<type as trait>::method(...) | 将类型转换为特征,再进行方法调用,以消除方法调用的二义性 |
表 B-4 展示了使用泛型参数时用到的符号。
表 B-4:泛型
| 符号 | 解释 |
|---|---|
path<...> | 为一个类型中的泛型指定具体参数(如 Vec<u8>) |
path::<...>, method::<...> | 为一个泛型、函数或表达式中的方法指定具体参数,通常指双冒号(turbofish)(如 "42".parse::<i32>()) |
fn ident<...> ... | 泛型函数定义 |
struct ident<...> ... | 泛型结构体定义 |
enum ident<...> ... | 泛型枚举定义 |
impl<...> ... | 实现泛型 |
for<...> type | 高阶生命周期限制 |
type<ident=type> | 泛型,其一个或多个相关类型必须被指定为特定类型(如 Iterator<Item=T>) |
表 B-5 展示了使用特征约束来限制泛型参数的符号。
表 B-5:特征约束
| 符号 | 解释 |
|---|---|
T: U | 泛型参数 T需实现U类型 |
T: 'a | 泛型 T 的生命周期必须长于 'a(意味着该类型不能传递包含生命周期短于 'a 的任何引用) |
T : 'static | 泛型 T 只能使用声明周期为'static 的引用 |
'b: 'a | 生命周期'b必须长于生命周期'a |
T: ?Sized | 使用一个不定大小的泛型类型 |
'a + trait, trait + trait | 多个类型组成的复合类型限制 |
表 B-6 展示了宏以及在一个对象上定义属性的符号。
表 B-6:宏与属性
| 符号 | 解释 |
|---|---|
#[meta] | 外部属性 |
#![meta] | 内部属性 |
$ident | 宏替换 |
$ident:kind | 宏捕获 |
$(…)… | 宏重复 |
ident!(...), ident!{...}, ident![...] | 宏调用 |
表 B-7 展示了写注释的符号。
表 B-7:注释
| 符号 | 注释 |
|---|---|
// | 行注释 |
//! | 内部行(hang)文档注释 |
/// | 外部行文档注释 |
/*...*/ | 块注释 |
/*!...*/ | 内部块文档注释 |
/**...*/ | 外部块文档注释 |
表 B-8 展示了出现在使用元组时的符号。
表 B-8:元组
| 符号 | 解释 |
|---|---|
() | 空元组(亦称单元),即是字面值也是类型 |
(expr) | 括号表达式 |
(expr,) | 单一元素元组表达式 |
(type,) | 单一元素元组类型 |
(expr, ...) | 元组表达式 |
(type, ...) | 元组类型 |
expr(expr, ...) | 函数调用表达式;也用于初始化元组结构体 struct 以及元组枚举 enum 变体 |
expr.0, expr.1, etc. | 元组索引 |
表 B-9 展示了使用大括号的上下文。
表 B-9:大括号
| 符号 | 解释 |
|---|---|
{...} | 代码块表达式 |
Type {...} | 结构体字面值 |
表 B-10 展示了使用方括号的上下文。
表 B-10:方括号
| 符号 | 解释 |
|---|---|
[...] | 数组 |
[expr; len] | 数组里包含len个expr |
[type; len] | 数组里包含了len个type类型的对象 |
expr[expr] | 集合索引。 重载(Index, IndexMut) |
expr[..], expr[a..], expr[..b], expr[a..b] | 集合索引,也称为集合切片,索引要实现以下特征中的其中一个:Range,RangeFrom,RangeTo 或 RangeFull |
附录 C:表达式
在语句与表达式章节中,我们对表达式有过介绍,下面对这些常用表达式进行一一说明。
基本表达式
#![allow(unused)] fn main() { let n = 3; let s = "test"; }
if 表达式
fn main() { let var1 = 10; let var2 = if var1 >= 10 { var1 } else { var1 + 10 }; println!("{}", var2); }
通过 if 表达式将值赋予 var2。
你还可以在循环中结合 continue 、break 来使用:
#![allow(unused)] fn main() { let mut v = 0; for i in 1..10 { v = if i == 9 { continue } else { i } } println!("{}", v); }
if let 表达式
#![allow(unused)] fn main() { let o = Some(3); let v = if let Some(x) = o { x } else { 0 }; }
match 表达式
#![allow(unused)] fn main() { let o = Some(3); let v = match o { Some(x) => x, _ => 0 }; }
loop 表达式
#![allow(unused)] fn main() { let mut n = 0; let v = loop { if n == 10 { break n } n += 1; }; }
语句块 {}
#![allow(unused)] fn main() { let mut n = 0; let v = { println!("before: {}", n); n += 1; println!("after: {}", n); n }; println!("{}", v); }
附录 D:派生特征 trait
在本书的各个部分中,我们讨论了可应用于结构体和枚举定义的 derive 属性。被 derive 标记的对象会自动实现对应的默认特征代码,继承相应的功能。
在本附录中,我们列举了所有标准库存在的 derive 特征,每个特征覆盖了以下内容
- 该特征将会派生什么样的操作符和方法
- 由
derive提供什么样的特征实现 - 实现特征对于类型意味着什么
- 你需要什么条件来实现该特征
- 特征示例
如果你希望不同于 derive 属性所提供的行为,请查阅 标准库文档 中每个特征的细节以了解如何手动实现它们。
除了本文列出的特征之外,标准库中定义的其它特征不能通过 derive 在类型上实现。这些特征不存在有意义的默认行为,所以由你负责以合理的方式实现它们。
一个无法被派生的特征例子是为终端用户处理格式化的 Display 。你应该时常考虑使用合适的方法来为终端用户显示一个类型。终端用户应该看到类型的什么部分?他们会找出相关部分吗?对他们来说最关心的数据格式是什么样的?Rust 编译器没有这样的洞察力,因此无法为你提供合适的默认行为。
本附录所提供的可派生特征列表其实并不全面:库可以为其内部的特征实现 derive ,因此除了本文列出的标准库 derive 之外,还有很多很多其它库的 derive 。实现 derive 涉及到过程宏的应用,这在宏章节中有介绍。
用于开发者输出的 Debug
Debug 特征可以让指定对象输出调试格式的字符串,通过在 {} 占位符中增加 :? 表明,例如println!("show you some debug info: {:?}", MyObject);.
Debug 特征允许以调试为目的来打印一个类型的实例,所以程序员可以在执行过程中看到该实例的具体信息。
例如,在使用 assert_eq! 宏时, Debug 特征是必须的。如果断言失败,这个宏就把给定实例的值打印出来,这样程序员就能看到两个实例为什么不相等。
等值比较的 PartialEq 和 Eq
PartialEq 特征可以比较一个类型的实例以检查是否相等,并开启了 == 和 != 运算符的功能。
派生的 PartialEq 实现了 eq 方法。当 PartialEq 在结构体上派生时,只有所有 的字段都相等时两个实例才相等,同时只要有任何字段不相等则两个实例就不相等。当在枚举上派生时,每一个成员都和其自身相等,且和其他成员都不相等。
例如,当使用 assert_eq! 宏时,需要比较一个类型的两个实例是否相等,则 PartialEq 特征是必须的。
Eq 特征没有方法, 其作用是表明每一个被标记类型的值都等于其自身。 Eq 特征只能应用于那些实现了 PartialEq 的类型,但并非所有实现了 PartialEq 的类型都可以实现 Eq。浮点类型就是一个例子:浮点数的实现表明两个非数字( NaN ,not-a-number)值是互不相等的。
例如,对于一个 HashMap<K, V> 中的 key 来说, Eq 是必须的,这样 HashMap<K, V> 就可以知道两个 key 是否一样。
次序比较的 PartialOrd 和 Ord
PartialOrd 特征可以让一个类型的多个实例实现排序功能。实现了 PartialOrd 的类型可以使用 <、 >、<= 和 >= 操作符。一个类型想要实现 PartialOrd 的前提是该类型已经实现了 PartialEq 。
派生 PartialOrd 实现了 partial_cmp 方法,一般情况下其返回一个 Option<Ordering>,但是当给定的值无法进行排序时将返回 None。尽管大多数类型的值都可以比较,但一个无法产生顺序的例子是:浮点类型的非数字值。当在浮点数上调用 partial_cmp 时, NaN 的浮点数将返回 None。
当在结构体上派生时, PartialOrd 以在结构体定义中字段出现的顺序比较每个字段的值来比较两个实例。当在枚举上派生时,认为在枚举定义中声明较早的枚举项小于其后的枚举项。
例如,对于来自于 rand 包的 gen_range 方法来说,当在一个大值和小值指定的范围内生成一个随机值时, PartialOrd trait 是必须的。
对于派生了 Ord 特征的类型,任何两个该类型的值都能进行排序。 Ord 特征实现了 cmp 方法,它返回一个 Ordering 而不是 Option<Ordering>,因为总存在一个合法的顺序。一个类型要想使用 Ord 特征,它必须要先实现 PartialOrd 和 Eq 。当在结构体或枚举上派生时, cmp 方法 和 PartialOrd 的 partial_cmp 方法表现是一致的。
例如,当在 BTreeSet<T>(一种基于有序值存储数据的数据结构)上存值时, Ord 是必须的。
复制值的 Clone 和 Copy
Clone 特征用于创建一个值的深拷贝(deep copy),复制过程可能包含代码的执行以及堆上数据的复制。查阅 通过 Clone 进行深拷贝获取有关 Clone 的更多信息。
派生 Clone 实现了 clone 方法,当为整个的类型实现 Clone 时,在该类型的每一部分上都会调用 clone 方法。这意味着类型中所有字段或值也必须实现了 Clone,这样才能够派生 Clone 。
例如,当在一个切片(slice)上调用 to_vec 方法时, Clone 是必须的。切片只是一个引用,并不拥有其所包含的实例数据,但是从 to_vec 中返回的 Vector 需要拥有实例数据,因此, to_vec 需要在每个元素上调用 clone 来逐个复制。因此,存储在切片中的类型必须实现 Clone。
Copy 特征允许你通过只拷贝存储在栈上的数据来复制值(浅拷贝),而无需复制存储在堆上的底层数据。查阅 通过 Copy 复制栈数据 的部分来获取有关 Copy 的更多信息。
实际上 Copy 特征并不阻止你在实现时使用了深拷贝,只是,我们不应该这么做,毕竟遵循一个语言的惯例是很重要的。当用户看到 Copy 时,潜意识就应该知道这是浅拷贝,复制一个值会非常快。
当一个类型的内部字段全部实现了 Copy 时,你就可以在该类型上派上 Copy 特征。 一个类型如果要实现 Copy 它必须先实现 Clone ,因为一个类型实现 Clone 后,就等于顺便实现了 Copy 。
总之, Copy 拥有更好的性能,当浅拷贝足够的时候,就不要使用 Clone ,不然会导致你的代码运行更慢,对于性能优化来说,一个很大的方面就是减少热点路径深拷贝的发生。
固定大小的值映射的 Hash
Hash 特征允许你使用 hash 函数把一个任意大小的实例映射到一个固定大小的值上。派生 Hash 实现了 hash 方法,对某个类型进行 hash 调用,其实就是对该类型下每个字段单独进行 hash 调用,然后把结果进行汇总,这意味着该类型下的所有的字段也必须实现了 Hash,这样才能够派生 Hash。
例如,在 HashMap<K, V> 上存储数据,存放 key 的时候, Hash 是必须的。
默认值的 Default
Default 特征会帮你创建一个类型的默认值。 派生 Default 意味着自动实现了 default 函数。 default 函数的派生实现调用了类型每部分的 default 函数,这意味着类型中所有的字段也必须实现了 Default,这样才能够派生 Default 。
Default::default 函数通常结合结构体更新语法一起使用,这在第五章的 结构体更新语法 部分有讨论。可以自定义一个结构体的一小部分字段而剩余字段则使用 ..Default::default() 设置为默认值。
例如,当你在 Option<T> 实例上使用 unwrap_or_default 方法时, Default 特征是必须的。如果 Option<T> 是 None 的话, unwrap_or_default 方法将返回 T 类型的 Default::default 的结果。
附录 E:prelude 模块
附录 F:Rust 版本发布
Rust 版本说明
早在第一章,我们见过 cargo new 在 Cargo.toml 中增加了一些有关 edition 的元数据。本附录将解释其意义!
与其它语言相比,Rust 的更新迭代较为频繁(得益于精心设计过的发布流程以及 Rust 语言开发者团队管理):
- 每 6 周发布一个迭代版本
- 2 - 3 年发布一个新的大版本:每一个版本会结合已经落地的功能,并提供一个清晰的带有完整更新文档和工具的功能包。新版本会作为常规的 6 周发布过程的一部分发布。
好处在于,可以满足不同的用户群体的需求:
- 对于活跃的 Rust 用户,他们总是能很快获取到新的语言内容,毕竟,尝鲜是技术爱好者的共同特点:)
- 对于一般的用户,edition 的发布会告诉这些用户:Rust 语言相比上次大版本发布,有了重大的改进,值得一看
- 对于 Rust 语言开发者,可以让他们的工作成果更快的被世人所知,不必锦衣夜行
在本文档编写时,Rust 已经有三个版本:Rust 2015、2018、2021。本书基于 Rust 2021 edition 编写。
Cargo.toml 中的 edition 字段表明代码应该使用哪个版本编译。如果该字段不存在,其默认为 2021 以提供后向兼容性。
每个项目都可以选择不同于默认的 Rust 2021 edition 的版本。这样,版本可能会包含不兼容的修改,比如新版本中新增的关键字可能会与老代码中的标识符冲突并导致错误。不过,除非你选择应用这些修改,否则旧代码依然能够被编译,即便你升级了编译器版本。
所有 Rust 编译器都支持任何之前存在的编译器版本,并可以链接任何支持版本的包。编译器修改只影响最初的解析代码的过程。因此,如果你使用 Rust 2021 而某个依赖使用 Rust 2018,你的项目仍旧能够编译并使用该依赖。反之,若项目使用 Rust 2018 而依赖使用 Rust 2021 亦可工作。
有一点需要明确:大部分功能在所有版本中都能使用。开发者使用任何 Rust 版本将能继续接收最新稳定版的改进。然而在一些情况,主要是增加了新关键字的时候,则可能出现了只能用于新版本的功能。只需切换版本即可利用新版本的功能。
请查看 Edition Guide 了解更多细节,这是一个完全介绍版本的书籍,包括如何通过 cargo fix 自动将代码迁移到新版本。
Rust 自身开发流程
本附录介绍 Rust 语言自身是如何开发的以及这如何影响作为 Rust 开发者的你。
无停滞稳定
作为一个语言,Rust 十分 注重代码的稳定性。我们希望 Rust 成为你代码坚实的基础,假如持续地有东西在变,这个希望就实现不了。但与此同时,如果不能实验新功能的话,在发布之前我们又无法发现其中重大的缺陷,而一旦发布便再也没有修改的机会了。
对于这个问题我们的解决方案被称为 “无停滞稳定”(“stability without stagnation”),其指导性原则是:无需担心升级到最新的稳定版 Rust。每次升级应该是无痛的,并应带来新功能,更少的 Bug 和更快的编译速度。
Choo, Choo! ~~ 小火车发布流程启动
开发 Rust 语言是基于一个火车时刻表来进行的:所有的开发工作在 Master 分支上完成,但是发布就像火车时刻表一样,拥有不同的时间,发布采用的软件发布列车模型,被用于思科 IOS 和等其它软件项目。Rust 有三个 发布通道(release channel):
- Nightly
- Beta
- Stable(稳定版)
大部分 Rust 开发者主要采用稳定版通道,不过希望实验新功能的开发者可能会使用 nightly 或 beta 版。
如下是一个开发和发布过程如何运转的例子:假设 Rust 团队正在进行 Rust 1.5 的发布工作。该版本发布于 2015 年 12 月,这个版本和时间显然比较老了,不过这里只是为了提供一个真实的版本。Rust 新增了一项功能:一个 master 分支的新提交。每天晚上,会产生一个新的 nightly 版本。每天都是发布版本的日子,而这些发布由发布基础设施自动完成。所以随着时间推移,发布轨迹看起来像这样,版本一天一发:
nightly: * - - * - - *
每 6 周时间,是准备发布新版本的时候了!Rust 仓库的 beta 分支会从用于 nightly 的 master 分支产生。现在,有了两个发布版本:
nightly: * - - * - - *
|
beta: *
大部分 Rust 用户不会主要使用 beta 版本,不过在 CI 系统中对 beta 版本进行测试能够帮助 Rust 发现可能的回归缺陷(regression)。同时,每天仍产生 nightly 发布:
nightly: * - - * - - * - - * - - *
|
beta: *
比如我们发现了一个回归缺陷。好消息是在这些缺陷流入稳定发布之前还有一些时间来测试 beta 版本!fix 被合并到 master,为此 nightly 版本得到了修复,接着这些 fix 将 backport 到 beta 分支,一个新的 beta 发布就产生了:
nightly: * - - * - - * - - * - - * - - *
|
beta: * - - - - - - - - *
第一个 beta 版的 6 周后,是发布稳定版的时候了!stable 分支从 beta 分支生成:
nightly: * - - * - - * - - * - - * - - * - * - *
|
beta: * - - - - - - - - *
|
stable: *
好的!Rust 1.5 发布了!然而,我们忘了些东西:因为又过了 6 周,我们还需发布 新版 Rust 的 beta 版,Rust 1.6。所以从 beta 分支生成 stable 分支后,新版的 beta 分支也再次从 nightly 生成:
nightly: * - - * - - * - - * - - * - - * - * - *
| |
beta: * - - - - - - - - * *
|
stable: *
这被称为 “train model”,因为每 6 周,一个版本 “离开车站”(“leaves the station”),不过从 beta 通道到达稳定通道还有一段旅程。
Rust 每 6 周发布一个版本,如时钟般准确。如果你知道了某个 Rust 版本的发布时间,就可以知道下个版本的时间:6 周后。每 6 周发布版本的一个好的方面是下一班车会来得更快。如果特定版本碰巧缺失某个功能也无需担心:另一个版本很快就会到来!这有助于减少因临近发版时间而偷偷释出未经完善的功能的压力。
多亏了这个过程,你总是可以切换到下一版本的 Rust 并验证是否可以轻易的升级:如果 beta 版不能如期工作,你可以向 Rust 团队报告并在发布稳定版之前得到修复!beta 版造成的破坏是非常少见的,不过 rustc 也不过是一个软件,可能会存在 Bug。
不稳定功能
这个发布模型中另一个值得注意的地方:不稳定功能(unstable features)。Rust 使用一个被称为 “功能标记”(“feature flags”)的技术来确定给定版本的某个功能是否启用。如果新功能正在积极地开发中,其提交到了 master,因此会出现在 nightly 版中,不过会位于一个 功能标记 之后。作为用户,如果你希望尝试这个正在开发的功能,则可以在源码中使用合适的标记来开启,不过必须使用 nightly 版。
如果使用的是 beta 或稳定版 Rust,则不能使用任何功能标记。这是在新功能被宣布为永久稳定之前获得实用价值的关键。这既满足了希望使用最尖端技术的同学,那些坚持稳定版的同学也知道其代码不会被破坏。这就是无停滞稳定。
本书只包含稳定的功能,因为还在开发中的功能仍可能改变,当其进入稳定版时肯定会与编写本书的时候有所不同。你可以在网上获取 nightly 版的文档。
Rustup 和 Rust Nightly 的职责
安装 Rust Nightly 版本
Rustup 使得改变不同发布通道的 Rust 更为简单,其在全局或分项目的层次工作。其默认会安装稳定版 Rust。例如为了安装 nightly:
$ rustup install nightly
你会发现 rustup 也安装了所有的 工具链(toolchains, Rust 和其相关组件)。如下是一位作者的 Windows 计算机上的例子:
> rustup toolchain list
stable-x86_64-pc-windows-msvc (default)
beta-x86_64-pc-windows-msvc
nightly-x86_64-pc-windows-msvc
在指定目录使用 Rust Nightly
如你所见,默认是稳定版。大部分 Rust 用户在大部分时间使用稳定版。你可能也会这么做,不过如果你关心最新的功能,可以为特定项目使用 nightly 版。为此,可以在项目目录使用 rustup override 来设置当前目录 rustup 使用 nightly 工具链:
$ cd ~/projects/needs-nightly
$ rustup override set nightly
现在,每次在 *~/需要 nightly 的项目/*下(在项目的根目录下,也就是 Cargo.toml 所在的目录) 调用 rustc 或 cargo,rustup 会确保使用 nightly 版 Rust。在你有很多 Rust 项目时大有裨益!
RFC 过程和团队
那么你如何了解这些新功能呢?Rust 开发模式遵循一个 Request For Comments (RFC) 过程。如果你希望改进 Rust,可以编写一个提议,也就是 RFC。
任何人都可以编写 RFC 来改进 Rust,同时这些 RFC 会被 Rust 团队评审和讨论,他们由很多不同分工的子团队组成。这里是 Rust 官网 上所有团队的总列表,其包含了项目中每个领域的团队:语言设计、编译器实现、基础设施、文档等。各个团队会阅读相应的提议和评论,编写回复,并最终达成接受或回绝功能的一致。
如果功能被接受了,在 Rust 仓库会打开一个 issue,人们就可以实现它。实现功能的人可能不是最初提议功能的人!当实现完成后,其会合并到 master 分支并位于一个特性开关(feature gate)之后,正如不稳定功能 部分所讨论的。
在稍后的某个时间,一旦使用 nightly 版的 Rust 团队能够尝试这个功能了,团队成员会讨论这个功能在 nightly 中运行的情况,并决定是否应该进入稳定版。如果决定继续推进,特性开关会移除,然后这个功能就被认为是稳定的了!乘着“发布的列车”,最终在新的稳定版 Rust 中出现。
附录 G:Rust 更新版本列表
本目录包含了 Rust 历次版本更新的重要内容解读,需要注意,每个版本实际更新的内容要比这里记录的更多,全部内容请访问每节开头的官方链接查看。
1.58
1.59
1.60
1.61
1.62
1.63
1.64
1.65
1.66
1.67
机器学习
机器学习(Machine Learning,ML)是指从有限的观测数据中学习(或“猜测”)出具有一般性的规律,并利用这些规律对未知数据进行预测的方法。 机器学习方法可以粗略地分为三个基本要素:模型、学习准则、优化算法.
- 模型:根据经验来假设一个函数集合 ℱ ,称为假设空间(Hypothesis Space),然后通过观测其在训练集 𝒟 上的特性,从中选择一个理想的假设(Hypothesis)𝑓∗∈ ℱ.通常分为线性&非线性.
- 学习准则:
- 损失函数:
- 风险最小化准则:经验风险最小化(Empirical Risk Minimization,ERM)准则与结构风险最小化(Structure Risk Minimization,SRM)准则. 前者是要求你对训练集的拟合,后者是保证泛化能力.
- 损失函数:
- 优化算法:在确定了训练集 𝒟 、假设空间 ℱ 以及学习准则后,如何找到最优的模型 𝑓(𝒙,𝜃∗) 就成了一个最优化(Optimization)问题.机器学习的训练过程其实就是最优化问题的求解过程.
传统的机器学习主要关注如何学习一个预测模型.一般需要首先将数据表示为一组特征(Feature),特征的表示形式可以是连续的数值、离散的符号或其他形式.然后将这些特征输入到预测模型,并输出预测结果.这类机器学习可以看作浅层学习(Shallow Learning).浅层学习的一个重要特点是不涉及特征学习,其特征主要靠人工经验或特征转换方法来抽取. 在实际任务中使用机器学习模型一般会包含以下几个步骤(如图1.2所示):
- 数据预处理:经过数据的预处理,如去除噪声等.比如在文本分类中,去除停用词等.
- 特征提取:从原始数据中提取一些有效的特征.比如在图像分类中,提取边缘、尺度不变特征变换(Scale Invariant Feature Transform,SIFT)特征等.
- 特征转换:对特征进行一定的加工,比如降维和升维.降维包括特征抽取(Feature Extraction)和特征选择(Feature Selection)两种途径.常用的特征转换方法有主成分分析(Principal Components Analysis,PCA)、线性判别分析(Linear Discriminant Analysis,LDA)等.
- 预测:机器学习的核心部分,学习一个函数并进行预测.
表示学习:为了提高机器学习系统的准确率,我们就需要将输入信息转换为有效的特征,或者更一般性地称为表示(Representation).如果有一种算法可以自动地学习出有效的特征,并提高最终机器学习模型的性能,那么这种学习就可以叫作表示学习(Representation Learning). 语义鸿沟:表示学习的关键是解决语义鸿沟(Semantic Gap)问题.即车这一概念是高层语义特征,轿车、自行车、卡车是底层特征,如何同不同车种提取出那个高层特征“车”呢?如果一个预测模型直接建立在底层特征之上,会导致对预测模型的能力要求过高.如果可以有一个好的表示在某种程度上能够反映出数据的高层语义特征,那么我们就能相对容易地构建后续的机器学习模型。
一. 绪论
1.1 机器学习的定义
正如我们根据过去的经验来判断明天的天气,吃货们希望从购买经验中挑选一个好瓜,那能不能让计算机帮助人类来实现这个呢?机器学习正是这样的一门学科,人的“经验”对应计算机中的“数据”,让计算机来学习这些经验数据,生成一个算法模型,在面对新的情况中,计算机便能作出有效的判断,这便是机器学习。
另一本经典教材的作者Mitchell给出了一个形式化的定义,假设:
- P:计算机程序在某任务类T上的性能。
- T:计算机程序希望实现的任务类。
- E:表示经验,即历史的数据集。
若该计算机程序通过利用经验E在任务T上获得了性能P的改善,则称该程序对E进行了学习。
1.2 机器学习的一些基本术语
假设我们收集了一批西瓜的数据,例如:(色泽=青绿;根蒂=蜷缩;敲声=浊响), (色泽=乌黑;根蒂=稍蜷;敲声=沉闷), (色泽=浅自;根蒂=硬挺;敲声=清脆)……每对括号内是一个西瓜的记录,定义:
- 所有记录的集合为:数据集。
- 每一条记录为:一个实例(instance)或样本(sample)。
- 例如:色泽或敲声,单个的特点为特征(feature)或属性(attribute)。
- 对于一条记录,如果在坐标轴上表示,每个西瓜都可以用坐标轴中的一个点表示,一个点也是一个向量,例如(青绿,蜷缩,浊响),即每个西瓜为:一个特征向量(feature vector)。
- 一个样本的特征数为:维数(dimensionality),该西瓜的例子维数为3,当维数非常大时,也就是现在说的“维数灾难”。
计算机程序学习经验数据生成算法模型的过程中,每一条记录称为一个“训练样本”,同时在训练好模型后,我们希望使用新的样本来测试模型的效果,则每一个新的样本称为一个“测试样本”。定义:
- 所有训练样本的集合为:训练集(trainning set),[特殊]。
- 所有测试样本的集合为:测试集(test set),[一般]。
- 机器学习出来的模型适用于新样本的能力为:泛化能力(generalization),即从特殊到一般。
西瓜的例子中,我们是想计算机通过学习西瓜的特征数据,训练出一个决策模型,来判断一个新的西瓜是否是好瓜。可以得知我们预测的是:西瓜是好是坏,即好瓜与差瓜两种,是离散值。同样地,也有通过历年的人口数据,来预测未来的人口数量,人口数量则是连续值。定义:
- 预测值为离散值的问题为:分类(classification)。
- 预测值为连续值的问题为:回归(regression)。
我们预测西瓜是否是好瓜的过程中,很明显对于训练集中的西瓜,我们事先已经知道了该瓜是否是好瓜,学习器通过学习这些好瓜或差瓜的特征,从而总结出规律,即训练集中的西瓜我们都做了标记,称为标记信息。但也有没有标记信息的情形,例如:我们想将一堆西瓜根据特征分成两个小堆,使得某一堆的西瓜尽可能相似,即都是好瓜或差瓜,对于这种问题,我们事先并不知道西瓜的好坏,样本没有标记信息。定义:
- 训练数据有标记信息的学习任务为:监督学习(supervised learning),容易知道上面所描述的分类和回归都是监督学习的范畴。
- 训练数据没有标记信息的学习任务为:无监督学习(unsupervised learning),常见的有聚类和关联规则。
二. 模型的评估与选择
2.1 误差与过拟合
我们将学习器对样本的实际预测结果与样本的真实值之间的差异成为:误差(error)。定义:
- 在训练集上的误差称为训练误差(training error)或经验误差(empirical error)。
- 在测试集上的误差称为测试误差(test error)。
- 学习器在所有新样本上的误差称为泛化误差(generalization error)。
显然,我们希望得到的是在新样本上表现得很好的学习器,即泛化误差小的学习器。因此,我们应该让学习器尽可能地从训练集中学出普适性的“一般特征”,这样在遇到新样本时才能做出正确的判别。然而,当学习器把训练集学得“太好”的时候,即把一些训练样本的自身特点当做了普遍特征;同时也有学习能力不足的情况,即训练集的基本特征都没有学习出来。我们定义:
- 学习能力过强,以至于把训练样本所包含的不太一般的特性都学到了,称为:过拟合(overfitting)。
- 学习能太差,训练样本的一般性质尚未学好,称为:欠拟合(underfitting)。
可以得知:在过拟合问题中,训练误差十分小,但测试误差教大;在欠拟合问题中,训练误差和测试误差都比较大。目前,欠拟合问题比较容易克服,例如增加迭代次数等,但过拟合问题还没有十分好的解决方案,过拟合是机器学习面临的关键障碍。

2.2 评估方法
在现实任务中,我们往往有多种算法可供选择,那么我们应该选择哪一个算法才是最适合的呢?如上所述,我们希望得到的是泛化误差小的学习器,理想的解决方案是对模型的泛化误差进行评估,然后选择泛化误差最小的那个学习器。但是,泛化误差指的是模型在所有新样本上的适用能力,我们无法直接获得泛化误差。
因此,通常我们采用一个“测试集”来测试学习器对新样本的判别能力,然后以“测试集”上的“测试误差”作为“泛化误差”的近似。显然:我们选取的测试集应尽可能与训练集互斥,下面用一个小故事来解释why:
假设老师出了10 道习题供同学们练习,考试时老师又用同样的这10道题作为试题,可能有的童鞋只会做这10 道题却能得高分,很明显:这个考试成绩并不能有效地反映出真实水平。回到我们的问题上来,我们希望得到泛化性能好的模型,好比希望同学们课程学得好并获得了对所学知识"举一反三"的能力;训练样本相当于给同学们练习的习题,测试过程则相当于考试。显然,若测试样本被用作训练了,则得到的将是过于"乐观"的估计结果。
2.3 训练集与测试集的划分方法
如上所述:我们希望用一个“测试集”的“测试误差”来作为“泛化误差”的近似,因此我们需要对初始数据集进行有效划分,划分出互斥的“训练集”和“测试集”。下面介绍几种常用的划分方法:
2.3.1 留出法
将数据集D划分为两个互斥的集合,一个作为训练集S,一个作为测试集T,满足D=S∪T且S∩T=∅,常见的划分为:大约2/3-4/5的样本用作训练,剩下的用作测试。需要注意的是:训练/测试集的划分要尽可能保持数据分布的一致性,以避免由于分布的差异引入额外的偏差,常见的做法是采取分层抽样。同时,由于划分的随机性,单次的留出法结果往往不够稳定,一般要采用若干次随机划分,重复实验取平均值的做法。
2.3.2 交叉验证法
将数据集D划分为k个大小相同的互斥子集,满足D=D1∪D2∪...∪Dk,Di∩Dj=∅(i≠j),同样地尽可能保持数据分布的一致性,即采用分层抽样的方法获得这些子集。交叉验证法的思想是:每次用k-1个子集的并集作为训练集,余下的那个子集作为测试集,这样就有K种训练集/测试集划分的情况,从而可进行k次训练和测试,最终返回k次测试结果的均值。交叉验证法也称“k折交叉验证”,k最常用的取值是10,下图给出了10折交叉验证的示意图。

与留出法类似,将数据集D划分为K个子集的过程具有随机性,因此K折交叉验证通常也要重复p次,称为p次k折交叉验证,常见的是10次10折交叉验证,即进行了100次训练/测试。特殊地当划分的k个子集的每个子集中只有一个样本时,称为“留一法”,显然,留一法的评估结果比较准确,但对计算机的消耗也是巨大的。
2.3.3 自助法
我们希望评估的是用整个D训练出的模型。但在留出法和交叉验证法中,由于保留了一部分样本用于测试,因此实际评估的模型所使用的训练集比D小,这必然会引入一些因训练样本规模不同而导致的估计偏差。留一法受训练样本规模变化的影响较小,但计算复杂度又太高了。“自助法”正是解决了这样的问题。
自助法的基本思想是:给定包含m个样本的数据集D,每次随机从D 中挑选一个样本,将其拷贝放入D',然后再将该样本放回初始数据集D 中,使得该样本在下次采样时仍有可能被采到。重复执行m 次,就可以得到了包含m个样本的数据集D'。可以得知在m次采样中,样本始终不被采到的概率取极限为:

这样,通过自助采样,初始样本集D中大约有36.8%的样本没有出现在D'中,于是可以将D'作为训练集,D-D'作为测试集。自助法在数据集较小,难以有效划分训练集/测试集时很有用,但由于自助法产生的数据集(随机抽样)改变了初始数据集的分布,因此引入了估计偏差。在初始数据集足够时,留出法和交叉验证法更加常用。
2.4 调参
大多数学习算法都有些参数(parameter) 需要设定,参数配置不同,学得模型的性能往往有显著差别,这就是通常所说的"参数调节"或简称"调参" (parameter tuning)。
学习算法的很多参数是在实数范围内取值,因此,对每种参数取值都训练出模型来是不可行的。常用的做法是:对每个参数选定一个范围和步长λ,这样使得学习的过程变得可行。例如:假定算法有3 个参数,每个参数仅考虑5 个候选值,这样对每一组训练/测试集就有555= 125 个模型需考察,由此可见:拿下一个参数(即经验值)对于算法人员来说是有多么的happy。
最后需要注意的是:当选定好模型和调参完成后,我们需要使用初始的数据集D重新训练模型,即让最初划分出来用于评估的测试集也被模型学习,增强模型的学习效果。用上面考试的例子来比喻:就像高中时大家每次考试完,要将考卷的题目消化掉(大多数题目都还是之前没有见过的吧?),这样即使考差了也能开心的玩耍了~
2.5 性能度量
性能度量(performance measure)是衡量模型泛化能力的评价标准,在对比不同模型的能力时,使用不同的性能度量往往会导致不同的评判结果。本节除2.5.1外,其它主要介绍分类模型的性能度量。
2.5.1 最常见的性能度量
在回归任务中,即预测连续值的问题,最常用的性能度量是“均方误差”(mean squared error),很多的经典算法都是采用了MSE作为评价函数,想必大家都十分熟悉。

在分类任务中,即预测离散值的问题,最常用的是错误率和精度,错误率是分类错误的样本数占样本总数的比例,精度则是分类正确的样本数占样本总数的比例,易知:错误率+精度=1。


2.5.2 查准率/查全率/F1
错误率和精度虽然常用,但不能满足所有的需求,例如:在推荐系统中,我们只关心推送给用户的内容用户是否感兴趣(即查准率),或者说所有用户感兴趣的内容我们推送出来了多少(即查全率)。因此,使用查准/查全率更适合描述这类问题。对于二分类问题,分类结果混淆矩阵与查准/查全率定义如下:

初次接触时,FN与FP很难正确的理解,按照惯性思维容易把FN理解成:False->Negtive,即将错的预测为错的,这样FN和TN就反了,后来找到一张图,描述得很详细,为方便理解,把这张图也贴在了下边:

正如天下没有免费的午餐,查准率和查全率是一对矛盾的度量。例如我们想让推送的内容尽可能用户全都感兴趣,那只能推送我们把握高的内容,这样就漏掉了一些用户感兴趣的内容,查全率就低了;如果想让用户感兴趣的内容都被推送,那只有将所有内容都推送上,宁可错杀一千,不可放过一个,这样查准率就很低了。
“P-R曲线”正是描述查准/查全率变化的曲线,P-R曲线定义如下:根据学习器的预测结果(一般为一个实值或概率)对测试样本进行排序,将最可能是“正例”的样本排在前面,最不可能是“正例”的排在后面,按此顺序逐个把样本作为“正例”进行预测,每次计算出当前的P值和R值,如下图所示:

P-R曲线如何评估呢?若一个学习器A的P-R曲线被另一个学习器B的P-R曲线完全包住,则称:B的性能优于A。若A和B的曲线发生了交叉,则谁的曲线下的面积大,谁的性能更优。但一般来说,曲线下的面积是很难进行估算的,所以衍生出了“平衡点”(Break-Event Point,简称BEP),即当P=R时的取值,平衡点的取值越高,性能更优。
P和R指标有时会出现矛盾的情况,这样就需要综合考虑他们,最常见的方法就是F-Measure,又称F-Score。F-Measure是P和R的加权调和平均,即:


特别地,当β=1时,也就是常见的F1度量,是P和R的调和平均,当F1较高时,模型的性能越好。


有时候我们会有多个二分类混淆矩阵,例如:多次训练或者在多个数据集上训练,那么估算全局性能的方法有两种,分为宏观和微观。简单理解,宏观就是先算出每个混淆矩阵的P值和R值,然后取得平均P值macro-P和平均R值macro-R,在算出Fβ或F1,而微观则是计算出混淆矩阵的平均TP、FP、TN、FN,接着进行计算P、R,进而求出Fβ或F1。

2.5.3 ROC与AUC
如上所述:学习器对测试样本的评估结果一般为一个实值或概率,设定一个阈值,大于阈值为正例,小于阈值为负例,因此这个实值的好坏直接决定了学习器的泛化性能,若将这些实值排序,则排序的好坏决定了学习器的性能高低。ROC曲线正是从这个角度出发来研究学习器的泛化性能,ROC曲线与P-R曲线十分类似,都是按照排序的顺序逐一按照正例预测,不同的是ROC曲线以“真正例率”(True Positive Rate,简称TPR)为横轴,纵轴为“假正例率”(False Positive Rate,简称FPR),ROC偏重研究基于测试样本评估值的排序好坏。


简单分析图像,可以得知:当FN=0时,TN也必须0,反之也成立,我们可以画一个队列,试着使用不同的截断点(即阈值)去分割队列,来分析曲线的形状,(0,0)表示将所有的样本预测为负例,(1,1)则表示将所有的样本预测为正例,(0,1)表示正例全部出现在负例之前的理想情况,(1,0)则表示负例全部出现在正例之前的最差情况。限于篇幅,这里不再论述。
现实中的任务通常都是有限个测试样本,因此只能绘制出近似ROC曲线。绘制方法:首先根据测试样本的评估值对测试样本排序,接着按照以下规则进行绘制。

同样地,进行模型的性能比较时,若一个学习器A的ROC曲线被另一个学习器B的ROC曲线完全包住,则称B的性能优于A。若A和B的曲线发生了交叉,则谁的曲线下的面积大,谁的性能更优。ROC曲线下的面积定义为AUC(Area Uder ROC Curve),不同于P-R的是,这里的AUC是可估算的,即AOC曲线下每一个小矩形的面积之和。易知:AUC越大,证明排序的质量越好,AUC为1时,证明所有正例排在了负例的前面,AUC为0时,所有的负例排在了正例的前面。

2.5.4 代价敏感错误率与代价曲线
上面的方法中,将学习器的犯错同等对待,但在现实生活中,将正例预测成假例与将假例预测成正例的代价常常是不一样的,例如:将无疾病-->有疾病只是增多了检查,但有疾病-->无疾病却是增加了生命危险。以二分类为例,由此引入了“代价矩阵”(cost matrix)。

在非均等错误代价下,我们希望的是最小化“总体代价”,这样“代价敏感”的错误率(2.5.1节介绍)为:

同样对于ROC曲线,在非均等错误代价下,演变成了“代价曲线”,代价曲线横轴是取值在[0,1]之间的正例概率代价,式中p表示正例的概率,纵轴是取值为[0,1]的归一化代价。


代价曲线的绘制很简单:设ROC曲线上一点的坐标为(TPR,FPR) ,则可相应计算出FNR,然后在代价平面上绘制一条从(0,FPR) 到(1,FNR) 的线段,线段下的面积即表示了该条件下的期望总体代价;如此将ROC 曲线土的每个点转化为代价平面上的一条线段,然后取所有线段的下界,围成的面积即为在所有条件下学习器的期望总体代价,如图所示:

在此模型的性能度量方法就介绍完了,以前一直以为均方误差和精准度就可以了,现在才发现天空如此广阔~
2.6 比较检验
在比较学习器泛化性能的过程中,统计假设检验(hypothesis test)为学习器性能比较提供了重要依据,即若A在某测试集上的性能优于B,那A学习器比B好的把握有多大。 为方便论述,本篇中都是以“错误率”作为性能度量的标准。
2.6.1 假设检验
“假设”指的是对样本总体的分布或已知分布中某个参数值的一种猜想,例如:假设总体服从泊松分布,或假设正态总体的期望u=u0。回到本篇中,我们可以通过测试获得测试错误率,但直观上测试错误率和泛化错误率相差不会太远,因此可以通过测试错误率来推测泛化错误率的分布,这就是一种假设检验。



2.6.2 交叉验证t检验

2.6.3 McNemar检验
MaNemar主要用于二分类问题,与成对t检验一样也是用于比较两个学习器的性能大小。主要思想是:若两学习器的性能相同,则A预测正确B预测错误数应等于B预测错误A预测正确数,即e01=e10,且|e01-e10|服从N(1,e01+e10)分布。

因此,如下所示的变量服从自由度为1的卡方分布,即服从标准正态分布N(0,1)的随机变量的平方和,下式只有一个变量,故自由度为1,检验的方法同上:做出假设-->求出满足显著度的临界点-->给出拒绝域-->验证假设。

2.6.4 Friedman检验与Nemenyi后续检验
上述的三种检验都只能在一组数据集上,F检验则可以在多组数据集进行多个学习器性能的比较,基本思想是在同一组数据集上,根据测试结果(例:测试错误率)对学习器的性能进行排序,赋予序值1,2,3...,相同则平分序值,如下图所示:

若学习器的性能相同,则它们的平均序值应该相同,且第i个算法的平均序值ri服从正态分布N((k+1)/2,(k+1)(k-1)/12),则有:


服从自由度为k-1和(k-1)(N-1)的F分布。下面是F检验常用的临界值:

若“H0:所有算法的性能相同”这个假设被拒绝,则需要进行后续检验,来得到具体的算法之间的差异。常用的就是Nemenyi后续检验。Nemenyi检验计算出平均序值差别的临界值域,下表是常用的qa值,若两个算法的平均序值差超出了临界值域CD,则相应的置信度1-α拒绝“两个算法性能相同”的假设。


2.7 偏差与方差
偏差-方差分解是解释学习器泛化性能的重要工具。在学习算法中,偏差指的是预测的期望值与真实值的偏差,方差则是每一次预测值与预测值得期望之间的差均方。实际上,偏差体现了学习器预测的准确度,而方差体现了学习器预测的稳定性。通过对泛化误差的进行分解,可以得到:
- 期望泛化误差=方差+偏差
- 偏差刻画学习器的拟合能力
- 方差体现学习器的稳定性
易知:方差和偏差具有矛盾性,这就是常说的偏差-方差窘境(bias-variance dilamma),随着训练程度的提升,期望预测值与真实值之间的差异越来越小,即偏差越来越小,但是另一方面,随着训练程度加大,学习算法对数据集的波动越来越敏感,方差值越来越大。换句话说:在欠拟合时,偏差主导泛化误差,而训练到一定程度后,偏差越来越小,方差主导了泛化误差。因此训练也不要贪杯,适度辄止。

三. 线性模型
谈及线性模型,其实我们很早就已经与它打过交道,还记得高中数学必修3课本中那个顽皮的“最小二乘法”吗?这就是线性模型的经典算法之一:根据给定的(x,y)点对,求出一条与这些点拟合效果最好的直线y=ax+b,之前我们利用下面的公式便可以计算出拟合直线的系数a,b(3.1中给出了具体的计算过程),从而对于一个新的x,可以预测它所对应的y值。前面我们提到:在机器学习的术语中,当预测值为连续值时,称为“回归问题”,离散值时为“分类问题”。本篇先从线性回归任务开始,接着讨论分类和多分类问题。

3.1 线性回归
线性回归问题就是试图学到一个线性模型尽可能准确地预测新样本的输出值,例如:通过历年的人口数据预测2017年人口数量。在这类问题中,往往我们会先得到一系列的有标记数据,例如:2000-->13亿...2016-->15亿,这时输入的属性只有一个,即年份;也有输入多属性的情形,假设我们预测一个人的收入,这时输入的属性值就不止一个了,例如:(学历,年龄,性别,颜值,身高,体重)-->15k。
有时这些输入的属性值并不能直接被我们的学习模型所用,需要进行相应的处理,对于连续值的属性,一般都可以被学习器所用,有时会根据具体的情形作相应的预处理,例如:归一化等;对于离散值的属性,可作下面的处理:
-
若属性值之间存在“序关系”,则可以将其转化为连续值,例如:身高属性分为“高”“中等”“矮”,可转化为数值:{1, 0.5, 0}。
-
若属性值之间不存在“序关系”,则通常将其转化为向量的形式,例如:性别属性分为“男”“女”,可转化为二维向量:{(1,0),(0,1)}。
(1)当输入属性只有一个的时候,就是最简单的情形,也就是我们高中时最熟悉的“最小二乘法”(Euclidean distance),首先计算出每个样本预测值与真实值之间的误差并求和,通过最小化均方误差MSE,使用求偏导等于零的方法计算出拟合直线y=wx+b的两个参数w和b,计算过程如下图所示:

(2)当输入属性有多个的时候,例如对于一个样本有d个属性{(x1,x2...xd),y},则y=wx+b需要写成:

通常对于多元问题,常常使用矩阵的形式来表示数据。在本问题中,将具有m个样本的数据集表示成矩阵X,将系数w与b合并成一个列向量,这样每个样本的预测值以及所有样本的均方误差最小化就可以写成下面的形式:



同样地,我们使用最小二乘法对w和b进行估计,令均方误差的求导等于0,需要注意的是,当一个矩阵的行列式不等于0时,我们才可能对其求逆,因此对于下式,我们需要考虑矩阵(X的转置*X)的行列式是否为0,若不为0,则可以求出其解,若为0,则需要使用其它的方法进行计算,书中提到了引入正则化,此处不进行深入。

另一方面,有时像上面这种原始的线性回归可能并不能满足需求,例如:y值并不是线性变化,而是在指数尺度上变化。这时我们可以采用线性模型来逼近y的衍生物,例如lny,这时衍生的线性模型如下所示,实际上就是相当于将指数曲线投影在一条直线上,如下图所示:

更一般地,考虑所有y的衍生物的情形,就得到了“广义的线性模型”(generalized linear model),其中,g(*)称为联系函数(link function)。

3.2 线性几率回归
回归就是通过输入的属性值得到一个预测值,利用上述广义线性模型的特征,是否可以通过一个联系函数,将预测值转化为离散值从而进行分类呢?线性几率回归正是研究这样的问题。对数几率引入了一个对数几率函数(logistic function),将预测值投影到0-1之间,从而将线性回归问题转化为二分类问题。


若将y看做样本为正例的概率,(1-y)看做样本为反例的概率,则上式实际上使用线性回归模型的预测结果器逼近真实标记的对数几率。因此这个模型称为“对数几率回归”(logistic regression),也有一些书籍称之为“逻辑回归”。下面使用最大似然估计的方法来计算出w和b两个参数的取值,下面只列出求解的思路,不列出具体的计算过程。


3.3 线性判别分析
线性判别分析(Linear Discriminant Analysis,简称LDA),其基本思想是:将训练样本投影到一条直线上,使得同类的样例尽可能近,不同类的样例尽可能远。如图所示:


想让同类样本点的投影点尽可能接近,不同类样本点投影之间尽可能远,即:让各类的协方差之和尽可能小,不用类之间中心的距离尽可能大。基于这样的考虑,LDA定义了两个散度矩阵。
- 类内散度矩阵(within-class scatter matrix)

- 类间散度矩阵(between-class scaltter matrix)

因此得到了LDA的最大化目标:“广义瑞利商”(generalized Rayleigh quotient)。

从而分类问题转化为最优化求解w的问题,当求解出w后,对新的样本进行分类时,只需将该样本点投影到这条直线上,根据与各个类别的中心值进行比较,从而判定出新样本与哪个类别距离最近。求解w的方法如下所示,使用的方法为λ乘子。

若将w看做一个投影矩阵,类似PCA的思想,则LDA可将样本投影到N-1维空间(N为类簇数),投影的过程使用了类别信息(标记信息),因此LDA也常被视为一种经典的监督降维技术。
3.4 多分类学习
现实中我们经常遇到不只两个类别的分类问题,即多分类问题,在这种情形下,我们常常运用“拆分”的策略,通过多个二分类学习器来解决多分类问题,即将多分类问题拆解为多个二分类问题,训练出多个二分类学习器,最后将多个分类结果进行集成得出结论。最为经典的拆分策略有三种:“一对一”(OvO)、“一对其余”(OvR)和“多对多”(MvM),核心思想与示意图如下所示。
-
OvO:给定数据集D,假定其中有N个真实类别,将这N个类别进行两两配对(一个正类/一个反类),从而产生N(N-1)/2个二分类学习器,在测试阶段,将新样本放入所有的二分类学习器中测试,得出N(N-1)个结果,最终通过投票产生最终的分类结果。
-
OvM:给定数据集D,假定其中有N个真实类别,每次取出一个类作为正类,剩余的所有类别作为一个新的反类,从而产生N个二分类学习器,在测试阶段,得出N个结果,若仅有一个学习器预测为正类,则对应的类标作为最终分类结果。
-
MvM:给定数据集D,假定其中有N个真实类别,每次取若干个类作为正类,若干个类作为反类(通过ECOC码给出,编码),若进行了M次划分,则生成了M个二分类学习器,在测试阶段(解码),得出M个结果组成一个新的码,最终通过计算海明/欧式距离选择距离最小的类别作为最终分类结果。


3.5 类别不平衡问题
类别不平衡(class-imbanlance)就是指分类问题中不同类别的训练样本相差悬殊的情况,例如正例有900个,而反例只有100个,这个时候我们就需要进行相应的处理来平衡这个问题。常见的做法有三种:
- 在训练样本较多的类别中进行“欠采样”(undersampling),比如从正例中采出100个,常见的算法有:EasyEnsemble。
- 在训练样本较少的类别中进行“过采样”(oversampling),例如通过对反例中的数据进行插值,来产生额外的反例,常见的算法有SMOTE。
- 直接基于原数据集进行学习,对预测值进行“再缩放”处理。其中再缩放也是代价敏感学习的基础。

四. 决策树
4.1 决策树基本概念
顾名思义,决策树是基于树结构来进行决策的,在网上看到一个例子十分有趣,放在这里正好合适。现想象一位捉急的母亲想要给自己的女娃介绍一个男朋友,于是有了下面的对话:
女儿:多大年纪了?
母亲:26。
女儿:长的帅不帅?
母亲:挺帅的。
女儿:收入高不?
母亲:不算很高,中等情况。
女儿:是公务员不?
母亲:是,在税务局上班呢。
女儿:那好,我去见见。
这个女孩的挑剔过程就是一个典型的决策树,即相当于通过年龄、长相、收入和是否公务员将男童鞋分为两个类别:见和不见。假设这个女孩对男人的要求是:30岁以下、长相中等以上并且是高收入者或中等以上收入的公务员,那么使用下图就能很好地表示女孩的决策逻辑(即一颗决策树)。

在上图的决策树中,决策过程的每一次判定都是对某一属性的“测试”,决策最终结论则对应最终的判定结果。一般一颗决策树包含:一个根节点、若干个内部节点和若干个叶子节点,易知:
- 每个非叶节点表示一个特征属性测试。
- 每个分支代表这个特征属性在某个值域上的输出。
- 每个叶子节点存放一个类别。
- 每个节点包含的样本集合通过属性测试被划分到子节点中,根节点包含样本全集。
4.2 决策树的构造
决策树的构造是一个递归的过程,有三种情形会导致递归返回:(1) 当前结点包含的样本全属于同一类别,这时直接将该节点标记为叶节点,并设为相应的类别;(2) 当前属性集为空,或是所有样本在所有属性上取值相同,无法划分,这时将该节点标记为叶节点,并将其类别设为该节点所含样本最多的类别;(3) 当前结点包含的样本集合为空,不能划分,这时也将该节点标记为叶节点,并将其类别设为父节点中所含样本最多的类别。算法的基本流程如下图所示:

可以看出:决策树学习的关键在于如何选择划分属性,不同的划分属性得出不同的分支结构,从而影响整颗决策树的性能。属性划分的目标是让各个划分出来的子节点尽可能地“纯”,即属于同一类别。因此下面便是介绍量化纯度的具体方法,决策树最常用的算法有三种:ID3,C4.5和CART。
4.2.1 ID3算法
ID3算法使用信息增益为准则来选择划分属性,“信息熵”(information entropy)是度量样本结合纯度的常用指标,假定当前样本集合D中第k类样本所占比例为pk,则样本集合D的信息熵定义为:

假定通过属性划分样本集D,产生了V个分支节点,v表示其中第v个分支节点,易知:分支节点包含的样本数越多,表示该分支节点的影响力越大。故可以计算出划分后相比原始数据集D获得的“信息增益”(information gain)。

信息增益越大,表示使用该属性划分样本集D的效果越好,因此ID3算法在递归过程中,每次选择最大信息增益的属性作为当前的划分属性。
4.2.2 C4.5算法
ID3算法存在一个问题,就是偏向于取值数目较多的属性,例如:如果存在一个唯一标识,这样样本集D将会被划分为|D|个分支,每个分支只有一个样本,这样划分后的信息熵为零,十分纯净,但是对分类毫无用处。因此C4.5算法使用了“增益率”(gain ratio)来选择划分属性,来避免这个问题带来的困扰。首先使用ID3算法计算出信息增益高于平均水平的候选属性,接着C4.5计算这些候选属性的增益率,增益率定义为:

4.2.3 CART算法
CART决策树使用“基尼指数”(Gini index)来选择划分属性,基尼指数反映的是从样本集D中随机抽取两个样本,其类别标记不一致的概率,因此Gini(D)越小越好,基尼指数定义如下:

进而,使用属性α划分后的基尼指数为:

4.3 剪枝处理
从决策树的构造流程中我们可以直观地看出:不管怎么样的训练集,决策树总是能很好地将各个类别分离开来,这时就会遇到之前提到过的问题:过拟合(overfitting),即太依赖于训练样本。剪枝(pruning)则是决策树算法对付过拟合的主要手段,剪枝的策略有两种如下:
- 预剪枝(prepruning):在构造的过程中先评估,再考虑是否分支。
- 后剪枝(post-pruning):在构造好一颗完整的决策树后,自底向上,评估分支的必要性。
评估指的是性能度量,即决策树的泛化性能。之前提到:可以使用测试集作为学习器泛化性能的近似,因此可以将数据集划分为训练集和测试集。预剪枝表示在构造数的过程中,对一个节点考虑是否分支时,首先计算决策树不分支时在测试集上的性能,再计算分支之后的性能,若分支对性能没有提升,则选择不分支(即剪枝)。后剪枝则表示在构造好一颗完整的决策树后,从最下面的节点开始,考虑该节点分支对模型的性能是否有提升,若无则剪枝,即将该节点标记为叶子节点,类别标记为其包含样本最多的类别。



上图分别表示不剪枝处理的决策树、预剪枝决策树和后剪枝决策树。预剪枝处理使得决策树的很多分支被剪掉,因此大大降低了训练时间开销,同时降低了过拟合的风险,但另一方面由于剪枝同时剪掉了当前节点后续子节点的分支,因此预剪枝“贪心”的本质阻止了分支的展开,在一定程度上带来了欠拟合的风险。而后剪枝则通常保留了更多的分支,因此采用后剪枝策略的决策树性能往往优于预剪枝,但其自底向上遍历了所有节点,并计算性能,训练时间开销相比预剪枝大大提升。
4.4 连续值与缺失值处理
对于连续值的属性,若每个取值作为一个分支则显得不可行,因此需要进行离散化处理,常用的方法为二分法,基本思想为:给定样本集D与连续属性α,二分法试图找到一个划分点t将样本集D在属性α上分为≤t与>t。
- 首先将α的所有取值按升序排列,所有相邻属性的均值作为候选划分点(n-1个,n为α所有的取值数目)。
- 计算每一个划分点划分集合D(即划分为两个分支)后的信息增益。
- 选择最大信息增益的划分点作为最优划分点。

现实中常会遇到不完整的样本,即某些属性值缺失。有时若简单采取剔除,则会造成大量的信息浪费,因此在属性值缺失的情况下需要解决两个问题:(1)如何选择划分属性。(2)给定划分属性,若某样本在该属性上缺失值,如何划分到具体的分支上。假定为样本集中的每一个样本都赋予一个权重,根节点中的权重初始化为1,则定义:

对于(1):通过在样本集D中选取在属性α上没有缺失值的样本子集,计算在该样本子集上的信息增益,最终的信息增益等于该样本子集划分后信息增益乘以样本子集占样本集的比重。即:

对于(2):若该样本子集在属性α上的值缺失,则将该样本以不同的权重(即每个分支所含样本比例)划入到所有分支节点中。该样本在分支节点中的权重变为:

五. 神经网络
在机器学习中,神经网络一般指的是“神经网络学习”,是机器学习与神经网络两个学科的交叉部分。所谓神经网络,目前用得最广泛的一个定义是“神经网络是由具有适应性的简单单元组成的广泛并行互连的网络,它的组织能够模拟生物神经系统对真实世界物体所做出的交互反应”。
5.1 神经元模型
神经网络中最基本的单元是神经元模型(neuron)。在生物神经网络的原始机制中,每个神经元通常都有多个树突(dendrite),一个轴突(axon)和一个细胞体(cell body),树突短而多分支,轴突长而只有一个;在功能上,树突用于传入其它神经元传递的神经冲动,而轴突用于将神经冲动传出到其它神经元,当树突或细胞体传入的神经冲动使得神经元兴奋时,该神经元就会通过轴突向其它神经元传递兴奋。神经元的生物学结构如下图所示,不得不说高中的生化知识大学忘得可是真干净...

一直沿用至今的“M-P神经元模型”正是对这一结构进行了抽象,也称“阈值逻辑单元“,其中树突对应于输入部分,每个神经元收到n个其他神经元传递过来的输入信号,这些信号通过带权重的连接传递给细胞体,这些权重又称为连接权(connection weight)。细胞体分为两部分,前一部分计算总输入值(即输入信号的加权和,或者说累积电平),后一部分先计算总输入值与该神经元阈值的差值,然后通过激活函数(activation function)的处理,产生输出从轴突传送给其它神经元。M-P神经元模型如下图所示:

与线性分类十分相似,神经元模型最理想的激活函数也是阶跃函数,即将神经元输入值与阈值的差值映射为输出值1或0,若差值大于零输出1,对应兴奋;若差值小于零则输出0,对应抑制。但阶跃函数不连续,不光滑,故在M-P神经元模型中,也采用Sigmoid函数来近似, Sigmoid函数将较大范围内变化的输入值挤压到 (0,1) 输出值范围内,所以也称为挤压函数(squashing function)。

将多个神经元按一定的层次结构连接起来,就得到了神经网络。它是一种包含多个参数的模型,比方说10个神经元两两连接,则有100个参数需要学习(每个神经元有9个连接权以及1个阈值),若将每个神经元都看作一个函数,则整个神经网络就是由这些函数相互嵌套而成。
5.2 感知机与多层网络
感知机(Perceptron)是由两层神经元组成的一个简单模型,但只有输出层是M-P神经元,即只有输出层神经元进行激活函数处理,也称为功能神经元(functional neuron);输入层只是接受外界信号(样本属性)并传递给输出层(输入层的神经元个数等于样本的属性数目),而没有激活函数。这样一来,感知机与之前线性模型中的对数几率回归的思想基本是一样的,都是通过对属性加权与另一个常数求和,再使用sigmoid函数将这个输出值压缩到0-1之间,从而解决分类问题。不同的是感知机的输出层应该可以有多个神经元,从而可以实现多分类问题,同时两个模型所用的参数估计方法十分不同。
给定训练集,则感知机的n+1个参数(n个权重+1个阈值)都可以通过学习得到。阈值Θ可以看作一个输入值固定为-1的哑结点的权重ωn+1,即假设有一个固定输入xn+1=-1的输入层神经元,其对应的权重为ωn+1,这样就把权重和阈值统一为权重的学习了。简单感知机的结构如下图所示:

感知机权重的学习规则如下:对于训练样本(x,y),当该样本进入感知机学习后,会产生一个输出值,若该输出值与样本的真实标记不一致,则感知机会对权重进行调整,若激活函数为阶跃函数,则调整的方法为(基于梯度下降法):

其中 η∈(0,1)称为学习率,可以看出感知机是通过逐个样本输入来更新权重,首先设定好初始权重(一般为随机),逐个地输入样本数据,若输出值与真实标记相同则继续输入下一个样本,若不一致则更新权重,然后再重新逐个检验,直到每个样本数据的输出值都与真实标记相同。容易看出:感知机模型总是能将训练数据的每一个样本都预测正确,和决策树模型总是能将所有训练数据都分开一样,感知机模型很容易产生过拟合问题。
由于感知机模型只有一层功能神经元,因此其功能十分有限,只能处理线性可分的问题,对于这类问题,感知机的学习过程一定会收敛(converge),因此总是可以求出适当的权值。但是对于像书上提到的异或问题,只通过一层功能神经元往往不能解决,因此要解决非线性可分问题,需要考虑使用多层功能神经元,即神经网络。多层神经网络的拓扑结构如下图所示:

在神经网络中,输入层与输出层之间的层称为隐含层或隐层(hidden layer),隐层和输出层的神经元都是具有激活函数的功能神经元。只需包含一个隐层便可以称为多层神经网络,常用的神经网络称为“多层前馈神经网络”(multi-layer feedforward neural network),该结构满足以下几个特点:
- 每层神经元与下一层神经元之间完全互连
- 神经元之间不存在同层连接
- 神经元之间不存在跨层连接

根据上面的特点可以得知:这里的“前馈”指的是网络拓扑结构中不存在环或回路,而不是指该网络只能向前传播而不能向后传播(下节中的BP神经网络正是基于前馈神经网络而增加了反馈调节机制)。神经网络的学习过程就是根据训练数据来调整神经元之间的“连接权”以及每个神经元的阈值,换句话说:神经网络所学习到的东西都蕴含在网络的连接权与阈值中。
5.3 BP神经网络算法
由上面可以得知:神经网络的学习主要蕴含在权重和阈值中,多层网络使用上面简单感知机的权重调整规则显然不够用了,BP神经网络算法即误差逆传播算法(error BackPropagation)正是为学习多层前馈神经网络而设计,BP神经网络算法是迄今为止最成功的的神经网络学习算法。
一般而言,只需包含一个足够多神经元的隐层,就能以任意精度逼近任意复杂度的连续函数[Hornik et al.,1989],故下面以训练单隐层的前馈神经网络为例,介绍BP神经网络的算法思想。

上图为一个单隐层前馈神经网络的拓扑结构,BP神经网络算法也使用梯度下降法(gradient descent),以单个样本的均方误差的负梯度方向对权重进行调节。可以看出:BP算法首先将误差反向传播给隐层神经元,调节隐层到输出层的连接权重与输出层神经元的阈值;接着根据隐含层神经元的均方误差,来调节输入层到隐含层的连接权值与隐含层神经元的阈值。BP算法基本的推导过程与感知机的推导过程原理是相同的,下面给出调整隐含层到输出层的权重调整规则的推导过程:

学习率η∈(0,1)控制着沿反梯度方向下降的步长,若步长太大则下降太快容易产生震荡,若步长太小则收敛速度太慢,一般地常把η设置为0.1,有时更新权重时会将输出层与隐含层设置为不同的学习率。BP算法的基本流程如下所示:

BP算法的更新规则是基于每个样本的预测值与真实类标的均方误差来进行权值调节,即BP算法每次更新只针对于单个样例。需要注意的是:BP算法的最终目标是要最小化整个训练集D上的累积误差,即:

如果基于累积误差最小化的更新规则,则得到了累积误差逆传播算法(accumulated error backpropagation),即每次读取全部的数据集一遍,进行一轮学习,从而基于当前的累积误差进行权值调整,因此参数更新的频率相比标准BP算法低了很多,但在很多任务中,尤其是在数据量很大的时候,往往标准BP算法会获得较好的结果。另外对于如何设置隐层神经元个数的问题,至今仍然没有好的解决方案,常使用“试错法”进行调整。
前面提到,BP神经网络强大的学习能力常常容易造成过拟合问题,有以下两种策略来缓解BP网络的过拟合问题:
- 早停:将数据分为训练集与测试集,训练集用于学习,测试集用于评估性能,若在训练过程中,训练集的累积误差降低,而测试集的累积误差升高,则停止训练。
- 引入正则化(regularization):基本思想是在累积误差函数中增加一个用于描述网络复杂度的部分,例如所有权值与阈值的平方和,其中λ∈(0,1)用于对累积经验误差与网络复杂度这两项进行折中,常通过交叉验证法来估计。

5.4 全局最小与局部最小
模型学习的过程实质上就是一个寻找最优参数的过程,例如BP算法试图通过最速下降来寻找使得累积经验误差最小的权值与阈值,在谈到最优时,一般会提到局部极小(local minimum)和全局最小(global minimum)。
- 局部极小解:参数空间中的某个点,其邻域点的误差函数值均不小于该点的误差函数值。
- 全局最小解:参数空间中的某个点,所有其他点的误差函数值均不小于该点的误差函数值。

要成为局部极小点,只要满足该点在参数空间中的梯度为零。局部极小可以有多个,而全局最小只有一个。全局最小一定是局部极小,但局部最小却不一定是全局最小。显然在很多机器学习算法中,都试图找到目标函数的全局最小。梯度下降法的主要思想就是沿着负梯度方向去搜索最优解,负梯度方向是函数值下降最快的方向,若迭代到某处的梯度为0,则表示达到一个局部最小,参数更新停止。因此在现实任务中,通常使用以下策略尽可能地去接近全局最小。
- 以多组不同参数值初始化多个神经网络,按标准方法训练,迭代停止后,取其中误差最小的解作为最终参数。
- 使用“模拟退火”技术,这里不做具体介绍。
- 使用随机梯度下降,即在计算梯度时加入了随机因素,使得在局部最小时,计算的梯度仍可能不为0,从而迭代可以继续进行。
5.5 深度学习
理论上,参数越多,模型复杂度就越高,容量(capability)就越大,从而能完成更复杂的学习任务。深度学习(deep learning)正是一种极其复杂而强大的模型。
怎么增大模型复杂度呢?两个办法,一是增加隐层的数目,二是增加隐层神经元的数目。前者更有效一些,因为它不仅增加了功能神经元的数量,还增加了激活函数嵌套的层数。但是对于多隐层神经网络,经典算法如标准BP算法往往会在误差逆传播时发散(diverge),无法收敛达到稳定状态。
那要怎么有效地训练多隐层神经网络呢?一般来说有以下两种方法:
-
无监督逐层训练(unsupervised layer-wise training):每次训练一层隐节点,把上一层隐节点的输出当作输入来训练,本层隐结点训练好后,输出再作为下一层的输入来训练,这称为预训练(pre-training)。全部预训练完成后,再对整个网络进行微调(fine-tuning)训练。一个典型例子就是深度信念网络(deep belief network,简称DBN)。这种做法其实可以视为把大量的参数进行分组,先找出每组较好的设置,再基于这些局部最优的结果来训练全局最优。
-
权共享(weight sharing):令同一层神经元使用完全相同的连接权,典型的例子是卷积神经网络(Convolutional Neural Network,简称CNN)。这样做可以大大减少需要训练的参数数目。

深度学习可以理解为一种特征学习(feature learning)或者表示学习(representation learning),无论是DBN还是CNN,都是通过多个隐层来把与输出目标联系不大的初始输入转化为与输出目标更加密切的表示,使原来只通过单层映射难以完成的任务变为可能。即通过多层处理,逐渐将初始的“低层”特征表示转化为“高层”特征表示,从而使得最后可以用简单的模型来完成复杂的学习任务。
传统任务中,样本的特征需要人类专家来设计,这称为特征工程(feature engineering)。特征好坏对泛化性能有至关重要的影响。而深度学习为全自动数据分析带来了可能,可以自动产生更好的特征。
六. 支持向量机
支持向量机是一种经典的二分类模型,基本模型定义为特征空间中最大间隔的线性分类器,其学习的优化目标便是间隔最大化,因此支持向量机本身可以转化为一个凸二次规划求解的问题。
6.1 函数间隔与几何间隔
对于二分类学习,假设现在的数据是线性可分的,这时分类学习最基本的想法就是找到一个合适的超平面,该超平面能够将不同类别的样本分开,类似二维平面使用ax+by+c=0来表示,超平面实际上表示的就是高维的平面,如下图所示:

对数据点进行划分时,易知:当超平面距离与它最近的数据点的间隔越大,分类的鲁棒性越好,即当新的数据点加入时,超平面对这些点的适应性最强,出错的可能性最小。因此需要让所选择的超平面能够最大化这个间隔Gap(如下图所示), 常用的间隔定义有两种,一种称之为函数间隔,一种为几何间隔,下面将分别介绍这两种间隔,并对SVM为什么会选用几何间隔做了一些阐述。

6.1.1 函数间隔
在超平面w'x+b=0确定的情况下,|w'x*+b|能够代表点x距离超平面的远近,易知:当w'x+b>0时,表示x在超平面的一侧(正类,类标为1),而当w'x+b<0时,则表示x在超平面的另外一侧(负类,类别为-1),因此(w'x+b)y* 的正负性恰能表示数据点x*是否被分类正确。于是便引出了函数间隔的定义(functional margin):

而超平面(w,b)关于所有样本点(Xi,Yi)的函数间隔最小值则为超平面在训练数据集T上的函数间隔:

可以看出:这样定义的函数间隔在处理SVM上会有问题,当超平面的两个参数w和b同比例改变时,函数间隔也会跟着改变,但是实际上超平面还是原来的超平面,并没有变化。例如:w1x1+w2x2+w3x3+b=0其实等价于2w1x1+2w2x2+2w3x3+2b=0,但计算的函数间隔却翻了一倍。从而引出了能真正度量点到超平面距离的概念--几何间隔(geometrical margin)。
6.1.2 几何间隔
几何间隔代表的则是数据点到超平面的真实距离,对于超平面w'x+b=0,w代表的是该超平面的法向量,设x为超平面外一点x在法向量w方向上的投影点,x与超平面的距离为r,则有x=x-r(w/||w||),又x在超平面上,即w'x+b=0,代入即可得:

为了得到r的绝对值,令r呈上其对应的类别y,即可得到几何间隔的定义:

从上述函数间隔与几何间隔的定义可以看出:实质上函数间隔就是|w'x+b|,而几何间隔就是点到超平面的距离。
6.2 最大间隔与支持向量
通过前面的分析可知:函数间隔不适合用来最大化间隔,因此这里我们要找的最大间隔指的是几何间隔,于是最大间隔分类器的目标函数定义为:

一般地,我们令r^为1(这样做的目的是为了方便推导和目标函数的优化),从而上述目标函数转化为:

对于y(w'x+b)=1的数据点,即下图中位于w'x+b=1或w'x+b=-1上的数据点,我们称之为支持向量(support vector),易知:对于所有的支持向量,它们恰好满足y*(w'x*+b)=1,而所有不是支持向量的点,有y*(w'x*+b)>1。

6.3 从原始优化问题到对偶问题
对于上述得到的目标函数,求1/||w||的最大值相当于求||w||^2的最小值,因此很容易将原来的目标函数转化为:

即变为了一个带约束的凸二次规划问题,按书上所说可以使用现成的优化计算包(QP优化包)求解,但由于SVM的特殊性,一般我们将原问题变换为它的对偶问题,接着再对其对偶问题进行求解。为什么通过对偶问题进行求解,有下面两个原因:
- 一是因为使用对偶问题更容易求解;
- 二是因为通过对偶问题求解出现了向量内积的形式,从而能更加自然地引出核函数。
对偶问题,顾名思义,可以理解成优化等价的问题,更一般地,是将一个原始目标函数的最小化转化为它的对偶函数最大化的问题。对于当前的优化问题,首先我们写出它的朗格朗日函数:

上式很容易验证:当其中有一个约束条件不满足时,L的最大值为 ∞(只需令其对应的α为 ∞即可);当所有约束条件都满足时,L的最大值为1/2||w||^2(此时令所有的α为0),因此实际上原问题等价于:

由于这个的求解问题不好做,因此一般我们将最小和最大的位置交换一下(需满足KKT条件) ,变成原问题的对偶问题:

这样就将原问题的求最小变成了对偶问题求最大(用对偶这个词还是很形象),接下来便可以先求L对w和b的极小,再求L对α的极大。
(1)首先求L对w和b的极小,分别求L关于w和b的偏导,可以得出:

将上述结果代入L得到:

(2)接着L关于α极大求解α(通过SMO算法求解,此处不做深入)。

(3)最后便可以根据求解出的α,计算出w和b,从而得到分类超平面函数。

在对新的点进行预测时,实际上就是将数据点x*代入分类函数f(x)=w'x+b中,若f(x)>0,则为正类,f(x)<0,则为负类,根据前面推导得出的w与b,分类函数如下所示,此时便出现了上面所提到的内积形式。

这里实际上只需计算新样本与支持向量的内积,因为对于非支持向量的数据点,其对应的拉格朗日乘子一定为0,根据最优化理论(K-T条件),对于不等式约束y(w'x+b)-1≥0,满足:

6.4 核函数
由于上述的超平面只能解决线性可分的问题,对于线性不可分的问题,例如:异或问题,我们需要使用核函数将其进行推广。一般地,解决线性不可分问题时,常常采用映射的方式,将低维原始空间映射到高维特征空间,使得数据集在高维空间中变得线性可分,从而再使用线性学习器分类。如果原始空间为有限维,即属性数有限,那么总是存在一个高维特征空间使得样本线性可分。若∅代表一个映射,则在特征空间中的划分函数变为:

按照同样的方法,先写出新目标函数的拉格朗日函数,接着写出其对偶问题,求L关于w和b的极大,最后运用SOM求解α。可以得出:
(1)原对偶问题变为:

(2)原分类函数变为:

求解的过程中,只涉及到了高维特征空间中的内积运算,由于特征空间的维数可能会非常大,例如:若原始空间为二维,映射后的特征空间为5维,若原始空间为三维,映射后的特征空间将是19维,之后甚至可能出现无穷维,根本无法进行内积运算了,此时便引出了核函数(Kernel)的概念。

因此,核函数可以直接计算隐式映射到高维特征空间后的向量内积,而不需要显式地写出映射后的结果,它虽然完成了将特征从低维到高维的转换,但最终却是在低维空间中完成向量内积计算,与高维特征空间中的计算等效**(低维计算,高维表现)**,从而避免了直接在高维空间无法计算的问题。引入核函数后,原来的对偶问题与分类函数则变为:
(1)对偶问题:

(2)分类函数:

因此,在线性不可分问题中,核函数的选择成了支持向量机的最大变数,若选择了不合适的核函数,则意味着将样本映射到了一个不合适的特征空间,则极可能导致性能不佳。同时,核函数需要满足以下这个必要条件:

由于核函数的构造十分困难,通常我们都是从一些常用的核函数中选择,下面列出了几种常用的核函数:

6.5 软间隔支持向量机
前面的讨论中,我们主要解决了两个问题:当数据线性可分时,直接使用最大间隔的超平面划分;当数据线性不可分时,则通过核函数将数据映射到高维特征空间,使之线性可分。然而在现实问题中,对于某些情形还是很难处理,例如数据中有噪声的情形,噪声数据(outlier)本身就偏离了正常位置,但是在前面的SVM模型中,我们要求所有的样本数据都必须满足约束,如果不要这些噪声数据还好,当加入这些outlier后导致划分超平面被挤歪了,如下图所示,对支持向量机的泛化性能造成很大的影响。

为了解决这一问题,我们需要允许某一些数据点不满足约束,即可以在一定程度上偏移超平面,同时使得不满足约束的数据点尽可能少,这便引出了**“软间隔”支持向量机**的概念
- 允许某些数据点不满足约束y(w'x+b)≥1;
- 同时又使得不满足约束的样本尽可能少。
这样优化目标变为:

如同阶跃函数,0/1损失函数虽然表示效果最好,但是数学性质不佳。因此常用其它函数作为“替代损失函数”。

支持向量机中的损失函数为hinge损失,引入**“松弛变量”**,目标函数与约束条件可以写为:

其中C为一个参数,控制着目标函数与新引入正则项之间的权重,这样显然每个样本数据都有一个对应的松弛变量,用以表示该样本不满足约束的程度,将新的目标函数转化为拉格朗日函数得到:

按照与之前相同的方法,先让L求关于w,b以及松弛变量的极小,再使用SMO求出α,有:

将w代入L化简,便得到其对偶问题:

将“软间隔”下产生的对偶问题与原对偶问题对比可以发现:新的对偶问题只是约束条件中的α多出了一个上限C,其它的完全相同,因此在引入核函数处理线性不可分问题时,便能使用与“硬间隔”支持向量机完全相同的方法。
七. 贝叶斯分类器
贝叶斯分类器是一种概率框架下的统计学习分类器,对分类任务而言,假设在相关概率都已知的情况下,贝叶斯分类器考虑如何基于这些概率为样本判定最优的类标。在开始介绍贝叶斯决策论之前,我们首先来回顾下概率论委员会常委--贝叶斯公式。

7.1 贝叶斯决策论
若将上述定义中样本空间的划分Bi看做为类标,A看做为一个新的样本,则很容易将条件概率理解为样本A是类别Bi的概率。在机器学习训练模型的过程中,往往我们都试图去优化一个风险函数,因此在概率框架下我们也可以为贝叶斯定义“条件风险”(conditional risk)。

我们的任务就是寻找一个判定准则最小化所有样本的条件风险总和,因此就有了贝叶斯判定准则(Bayes decision rule):为最小化总体风险,只需在每个样本上选择那个使得条件风险最小的类标。

若损失函数λ取0-1损失,则有:

即对于每个样本x,选择其后验概率P(c | x)最大所对应的类标,能使得总体风险函数最小,从而将原问题转化为估计后验概率P(c | x)。一般这里有两种策略来对后验概率进行估计:
- 判别式模型:直接对 P(c | x)进行建模求解。例我们前面所介绍的决策树、神经网络、SVM都是属于判别式模型。
- 生成式模型:通过先对联合分布P(x,c)建模,从而进一步求解 P(c | x)。
贝叶斯分类器就属于生成式模型,基于贝叶斯公式对后验概率P(c | x) 进行一项神奇的变换,巴拉拉能量.... P(c | x)变身:

对于给定的样本x,P(x)与类标无关,P(c)称为类先验概率,p(x | c )称为类条件概率。这时估计后验概率P(c | x)就变成为估计类先验概率和类条件概率的问题。对于先验概率和后验概率,在看这章之前也是模糊了我好久,这里普及一下它们的基本概念。
- 先验概率: 根据以往经验和分析得到的概率。
- 后验概率:后验概率是基于新的信息,修正原来的先验概率后所获得的更接近实际情况的概率估计。
实际上先验概率就是在没有任何结果出来的情况下估计的概率,而后验概率则是在有一定依据后的重新估计,直观意义上后验概率就是条件概率。下面直接上Wiki上的一个例子,简单粗暴快速完事...

回归正题,对于类先验概率P(c),p(c)就是样本空间中各类样本所占的比例,根据大数定理(当样本足够多时,频率趋于稳定等于其概率),这样当训练样本充足时,p(c)可以使用各类出现的频率来代替。因此只剩下类条件概率p(x | c ),它表达的意思是在类别c中出现x的概率,它涉及到属性的联合概率问题,若只有一个离散属性还好,当属性多时采用频率估计起来就十分困难,因此这里一般采用极大似然法进行估计。
7.2 极大似然法
极大似然估计(Maximum Likelihood Estimation,简称MLE),是一种根据数据采样来估计概率分布的经典方法。常用的策略是先假定总体具有某种确定的概率分布,再基于训练样本对概率分布的参数进行估计。运用到类条件概率p(x | c )中,假设p(x | c )服从一个参数为θ的分布,问题就变为根据已知的训练样本来估计θ。极大似然法的核心思想就是:估计出的参数使得已知样本出现的概率最大,即使得训练数据的似然最大。

所以,贝叶斯分类器的训练过程就是参数估计。总结最大似然法估计参数的过程,一般分为以下四个步骤:
- 1.写出似然函数;
- 2.对似然函数取对数,并整理;
- 3.求导数,令偏导数为0,得到似然方程组;
- 4.解似然方程组,得到所有参数即为所求。
例如:假设样本属性都是连续值,p(x | c )服从一个多维高斯分布,则通过MLE计算出的参数刚好分别为:

上述结果看起来十分合乎实际,但是采用最大似然法估计参数的效果很大程度上依赖于作出的假设是否合理,是否符合潜在的真实数据分布。这就需要大量的经验知识,搞统计越来越值钱也是这个道理,大牛们掐指一算比我们搬砖几天更有效果。
7.3 朴素贝叶斯分类器
不难看出:原始的贝叶斯分类器最大的问题在于联合概率密度函数的估计,首先需要根据经验来假设联合概率分布,其次当属性很多时,训练样本往往覆盖不够,参数的估计会出现很大的偏差。为了避免这个问题,朴素贝叶斯分类器(naive Bayes classifier)采用了“属性条件独立性假设”,即样本数据的所有属性之间相互独立。这样类条件概率p(x | c )可以改写为:

这样,为每个样本估计类条件概率变成为每个样本的每个属性估计类条件概率。

相比原始贝叶斯分类器,朴素贝叶斯分类器基于单个的属性计算类条件概率更加容易操作,需要注意的是:若某个属性值在训练集中和某个类别没有一起出现过,这样会抹掉其它的属性信息,因为该样本的类条件概率被计算为0。因此在估计概率值时,常常用进行平滑(smoothing)处理,拉普拉斯修正(Laplacian correction)就是其中的一种经典方法,具体计算方法如下:

当训练集越大时,拉普拉斯修正引入的影响越来越小。对于贝叶斯分类器,模型的训练就是参数估计,因此可以事先将所有的概率储存好,当有新样本需要判定时,直接查表计算即可。
八. EM算法
EM(Expectation-Maximization)算法是一种常用的估计参数隐变量的利器,也称为“期望最大算法”,是数据挖掘的十大经典算法之一。EM算法主要应用于训练集样本不完整即存在隐变量时的情形(例如某个属性值未知),通过其独特的“两步走”策略能较好地估计出隐变量的值。
8.1 EM算法思想
EM是一种迭代式的方法,它的基本思想就是:若样本服从的分布参数θ已知,则可以根据已观测到的训练样本推断出隐变量Z的期望值(E步),若Z的值已知则运用最大似然法估计出新的θ值(M步)。重复这个过程直到Z和θ值不再发生变化。
简单来讲:假设我们想估计A和B这两个参数,在开始状态下二者都是未知的,但如果知道了A的信息就可以得到B的信息,反过来知道了B也就得到了A。可以考虑首先赋予A某种初值,以此得到B的估计值,然后从B的当前值出发,重新估计A的取值,这个过程一直持续到收敛为止。

现在再来回想聚类的代表算法K-Means:【首先随机选择类中心=>将样本点划分到类簇中=>重新计算类中心=>不断迭代直至收敛】,不难发现这个过程和EM迭代的方法极其相似,事实上,若将样本的类别看做为“隐变量”(latent variable)Z,类中心看作样本的分布参数θ,K-Means就是通过EM算法来进行迭代的,与我们这里不同的是,K-Means的目标是最小化样本点到其对应类中心的距离和,上述为极大化似然函数。
8.2 EM算法数学推导
在上篇极大似然法中,当样本属性值都已知时,我们很容易通过极大化对数似然,接着对每个参数求偏导计算出参数的值。但当存在隐变量时,就无法直接求解,此时我们通常最大化已观察数据的对数“边际似然”(marginal likelihood)。

这时候,通过边缘似然将隐变量Z引入进来,对于参数估计,现在与最大似然不同的只是似然函数式中多了一个未知的变量Z,也就是说我们的目标是找到适合的θ和Z让L(θ)最大,这样我们也可以分别对未知的θ和Z求偏导,再令其等于0。
然而观察上式可以发现,和的对数(ln(x1+x2+x3))求导十分复杂,那能否通过变换上式得到一种求导简单的新表达式呢?这时候 Jensen不等式就派上用场了,先回顾一下高等数学凸函数的内容:
Jensen's inequality:过一个凸函数上任意两点所作割线一定在这两点间的函数图象的上方。理解起来也十分简单,对于凸函数f(x)''>0,即曲线的变化率是越来越大单调递增的,所以函数越到后面增长越厉害,这样在一个区间下,函数的均值就会大一些了。

因为ln(*)函数为凹函数,故可以将上式“和的对数”变为“对数的和”,这样就很容易求导了。

接着求解Qi和θ:首先固定θ(初始值),通过求解Qi使得J(θ,Q)在θ处与L(θ)相等,即求出L(θ)的下界;然后再固定Qi,调整θ,最大化下界J(θ,Q)。不断重复两个步骤直到稳定。通过jensen不等式的性质,Qi的计算公式实际上就是后验概率:

通过数学公式的推导,简单来理解这一过程:固定θ计算Q的过程就是在建立L(θ)的下界,即通过jenson不等式得到的下界(E步);固定Q计算θ则是使得下界极大化(M步),从而不断推高边缘似然L(θ)。从而循序渐进地计算出L(θ)取得极大值时隐变量Z的估计值。
EM算法也可以看作一种“坐标下降法”,首先固定一个值,对另外一个值求极值,不断重复直到收敛。这时候也许大家就有疑问,问什么不直接这两个家伙求偏导用梯度下降呢?这时候就是坐标下降的优势,有些特殊的函数,例如曲线函数z=y^2+x^2+x^2y+xy+...,无法直接求导,这时如果先固定其中的一个变量,再对另一个变量求极值,则变得可行。

8.3 EM算法流程
看完数学推导,算法的流程也就十分简单了,这里有两个版本,版本一来自西瓜书,周天使的介绍十分简洁;版本二来自于大牛的博客。结合着数学推导,自认为版本二更具有逻辑性,两者唯一的区别就在于版本二多出了红框的部分.
版本一:

版本二:

九. 集成学习
顾名思义,集成学习(ensemble learning)指的是将多个学习器进行有效地结合,组建一个“学习器委员会”,其中每个学习器担任委员会成员并行使投票表决权,使得委员会最后的决定更能够四方造福普度众生~...~,即其泛化性能要能优于其中任何一个学习器。
9.1 个体与集成
集成学习的基本结构为:先产生一组个体学习器,再使用某种策略将它们结合在一起。集成模型如下图所示:

在上图的集成模型中,若个体学习器都属于同一类别,例如都是决策树或都是神经网络,则称该集成为同质的(homogeneous);若个体学习器包含多种类型的学习算法,例如既有决策树又有神经网络,则称该集成为异质的(heterogenous)。
同质集成:个体学习器称为“基学习器”(base learner),对应的学习算法为“基学习算法”(base learning algorithm)。 异质集成:个体学习器称为“组件学习器”(component learner)或直称为“个体学习器”。
上面我们已经提到要让集成起来的泛化性能比单个学习器都要好,虽说团结力量大但也有木桶短板理论调皮捣蛋,那如何做到呢?这就引出了集成学习的两个重要概念:准确性和多样性(diversity)。准确性指的是个体学习器不能太差,要有一定的准确度;多样性则是个体学习器之间的输出要具有差异性。通过下面的这三个例子可以很容易看出这一点,准确度较高,差异度也较高,可以较好地提升集成性能。

现在考虑二分类的简单情形,假设基分类器之间相互独立(能提供较高的差异度),且错误率相等为 ε,则可以将集成器的预测看做一个伯努利实验,易知当所有基分类器中不足一半预测正确的情况下,集成器预测错误,所以集成器的错误率可以计算为:

此时,集成器错误率随着基分类器的个数的增加呈指数下降,但前提是基分类器之间相互独立,在实际情形中显然是不可能的,假设训练有A和B两个分类器,对于某个测试样本,显然满足:P(A=1 | B=1)> P(A=1),因为A和B为了解决相同的问题而训练,因此在预测新样本时存在着很大的联系。因此,个体学习器的“准确性”和“差异性”本身就是一对矛盾的变量,准确性高意味着牺牲多样性,所以产生“好而不同”的个体学习器正是集成学习研究的核心。现阶段有三种主流的集成学习方法:Boosting、Bagging以及随机森林(Random Forest),接下来将进行逐一介绍。
9.2 Boosting
Boosting是一种串行的工作机制,即个体学习器的训练存在依赖关系,必须一步一步序列化进行。其基本思想是:增加前一个基学习器在训练训练过程中预测错误样本的权重,使得后续基学习器更加关注这些打标错误的训练样本,尽可能纠正这些错误,一直向下串行直至产生需要的T个基学习器,Boosting最终对这T个学习器进行加权结合,产生学习器委员会。
Boosting族算法最著名、使用最为广泛的就是AdaBoost,因此下面主要是对AdaBoost算法进行介绍。AdaBoost使用的是指数损失函数,因此AdaBoost的权值与样本分布的更新都是围绕着最小化指数损失函数进行的。看到这里回想一下之前的机器学习算法,不难发现机器学习的大部分带参模型只是改变了最优化目标中的损失函数:如果是Square loss,那就是最小二乘了;如果是Hinge Loss,那就是著名的SVM了;如果是log-Loss,那就是Logistic Regression了。
定义基学习器的集成为加权结合,则有:

AdaBoost算法的指数损失函数定义为:

具体说来,整个Adaboost 迭代算法分为3步:
- 初始化训练数据的权值分布。如果有N个样本,则每一个训练样本最开始时都被赋予相同的权值:1/N。
- 训练弱分类器。具体训练过程中,如果某个样本点已经被准确地分类,那么在构造下一个训练集中,它的权值就被降低;相反,如果某个样本点没有被准确地分类,那么它的权值就得到提高。然后,权值更新过的样本集被用于训练下一个分类器,整个训练过程如此迭代地进行下去。
- 将各个训练得到的弱分类器组合成强分类器。各个弱分类器的训练过程结束后,加大分类误差率小的弱分类器的权重,使其在最终的分类函数中起着较大的决定作用,而降低分类误差率大的弱分类器的权重,使其在最终的分类函数中起着较小的决定作用。
整个AdaBoost的算法流程如下所示:

可以看出:AdaBoost的核心步骤就是计算基学习器权重和样本权重分布,那为何是上述的计算公式呢?这就涉及到了我们之前为什么说大部分带参机器学习算法只是改变了损失函数,就是因为大部分模型的参数都是通过最优化损失函数(可能还加个规则项)而计算(梯度下降,坐标下降等)得到,这里正是通过最优化指数损失函数从而得到这两个参数的计算公式,具体的推导过程此处不进行展开。
Boosting算法要求基学习器能对特定分布的数据进行学习,即每次都更新样本分布权重,这里书上提到了两种方法:“重赋权法”(re-weighting)和“重采样法”(re-sampling),书上的解释有些晦涩,这里进行展开一下:
重赋权法 : 对每个样本附加一个权重,这时涉及到样本属性与标签的计算,都需要乘上一个权值。 重采样法 : 对于一些无法接受带权样本的及学习算法,适合用“重采样法”进行处理。方法大致过程是,根据各个样本的权重,对训练数据进行重采样,初始时样本权重一样,每个样本被采样到的概率一致,每次从N个原始的训练样本中按照权重有放回采样N个样本作为训练集,然后计算训练集错误率,然后调整权重,重复采样,集成多个基学习器。
从偏差-方差分解来看:Boosting算法主要关注于降低偏差,每轮的迭代都关注于训练过程中预测错误的样本,将弱学习提升为强学习器。从AdaBoost的算法流程来看,标准的AdaBoost只适用于二分类问题。在此,当选为数据挖掘十大算法之一的AdaBoost介绍到这里,能够当选正是说明这个算法十分婀娜多姿,背后的数学证明和推导充分证明了这一点,限于篇幅不再继续展开。
9.3 Bagging与Random Forest
相比之下,Bagging与随机森林算法就简洁了许多,上面已经提到产生“好而不同”的个体学习器是集成学习研究的核心,即在保证基学习器准确性的同时增加基学习器之间的多样性。而这两种算法的基本思(tao)想(lu)都是通过“自助采样”的方法来增加多样性。
9.3.1 Bagging
Bagging是一种并行式的集成学习方法,即基学习器的训练之间没有前后顺序可以同时进行,Bagging使用“有放回”采样的方式选取训练集,对于包含m个样本的训练集,进行m次有放回的随机采样操作,从而得到m个样本的采样集,这样训练集中有接近36.8%的样本没有被采到。按照相同的方式重复进行,我们就可以采集到T个包含m个样本的数据集,从而训练出T个基学习器,最终对这T个基学习器的输出进行结合。

Bagging算法的流程如下所示:

可以看出Bagging主要通过样本的扰动来增加基学习器之间的多样性,因此Bagging的基学习器应为那些对训练集十分敏感的不稳定学习算法,例如:神经网络与决策树等。从偏差-方差分解来看,Bagging算法主要关注于降低方差,即通过多次重复训练提高稳定性。不同于AdaBoost的是,Bagging可以十分简单地移植到多分类、回归等问题。总的说起来则是:AdaBoost关注于降低偏差,而Bagging关注于降低方差。
9.3.2 随机森林
随机森林(Random Forest)是Bagging的一个拓展体,它的基学习器固定为决策树,多棵树也就组成了森林,而“随机”则在于选择划分属性的随机,随机森林在训练基学习器时,也采用有放回采样的方式添加样本扰动,同时它还引入了一种属性扰动,即在基决策树的训练过程中,在选择划分属性时,RF先从候选属性集中随机挑选出一个包含K个属性的子集,再从这个子集中选择最优划分属性,一般推荐K=log2(d)。
这样随机森林中基学习器的多样性不仅来自样本扰动,还来自属性扰动,从而进一步提升了基学习器之间的差异度。相比决策树的Bagging集成,随机森林的起始性能较差(由于属性扰动,基决策树的准确度有所下降),但随着基学习器数目的增多,随机森林往往会收敛到更低的泛化误差。同时不同于Bagging中决策树从所有属性集中选择最优划分属性,随机森林只在属性集的一个子集中选择划分属性,因此训练效率更高。

9.4 结合策略
结合策略指的是在训练好基学习器后,如何将这些基学习器的输出结合起来产生集成模型的最终输出,下面将介绍一些常用的结合策略:
9.4.1 平均法(回归问题)


易知简单平均法是加权平均法的一种特例,加权平均法可以认为是集成学习研究的基本出发点。由于各个基学习器的权值在训练中得出,一般而言,在个体学习器性能相差较大时宜使用加权平均法,在个体学习器性能相差较小时宜使用简单平均法。
9.4.2 投票法(分类问题)



绝对多数投票法(majority voting)提供了拒绝选项,这在可靠性要求很高的学习任务中是一个很好的机制。同时,对于分类任务,各个基学习器的输出值有两种类型,分别为类标记和类概率。

一些在产生类别标记的同时也生成置信度的学习器,置信度可转化为类概率使用,一般基于类概率进行结合往往比基于类标记进行结合的效果更好,需要注意的是对于异质集成,其类概率不能直接进行比较,此时需要将类概率转化为类标记输出,然后再投票。
9.4.3 学习法
学习法是一种更高级的结合策略,即学习出一种“投票”的学习器,Stacking是学习法的典型代表。Stacking的基本思想是:首先训练出T个基学习器,对于一个样本它们会产生T个输出,将这T个基学习器的输出与该样本的真实标记作为新的样本,m个样本就会产生一个m*T的样本集,来训练一个新的“投票”学习器。投票学习器的输入属性与学习算法对Stacking集成的泛化性能有很大的影响,书中已经提到:投票学习器采用类概率作为输入属性,选用多响应线性回归(MLR)一般会产生较好的效果。

9.5 多样性(diversity)
在集成学习中,基学习器之间的多样性是影响集成器泛化性能的重要因素。因此增加多样性对于集成学习研究十分重要,一般的思路是在学习过程中引入随机性,常见的做法主要是对数据样本、输入属性、输出表示、算法参数进行扰动。
数据样本扰动,即利用具有差异的数据集来训练不同的基学习器。例如:有放回自助采样法,但此类做法只对那些不稳定学习算法十分有效,例如:决策树和神经网络等,训练集的稍微改变能导致学习器的显著变动。 输入属性扰动,即随机选取原空间的一个子空间来训练基学习器。例如:随机森林,从初始属性集中抽取子集,再基于每个子集来训练基学习器。但若训练集只包含少量属性,则不宜使用属性扰动。 输出表示扰动,此类做法可对训练样本的类标稍作变动,或对基学习器的输出进行转化。 算法参数扰动,通过随机设置不同的参数,例如:神经网络中,随机初始化权重与随机设置隐含层节点数。
十. 聚类算法
聚类是一种经典的无监督学习方法,无监督学习的目标是通过对无标记训练样本的学习,发掘和揭示数据集本身潜在的结构与规律,即不依赖于训练数据集的类标记信息。聚类则是试图将数据集的样本划分为若干个互不相交的类簇,从而每个簇对应一个潜在的类别。
聚类直观上来说是将相似的样本聚在一起,从而形成一个类簇(cluster)。那首先的问题是如何来度量相似性(similarity measure)呢?这便是距离度量,在生活中我们说差别小则相似,对应到多维样本,每个样本可以对应于高维空间中的一个数据点,若它们的距离相近,我们便可以称它们相似。那接着如何来评价聚类结果的好坏呢?这便是性能度量,性能度量为评价聚类结果的好坏提供了一系列有效性指标。
10.1 距离度量
谈及距离度量,最熟悉的莫过于欧式距离了,从年头一直用到年尾的距离计算公式:即对应属性之间相减的平方和再开根号。度量距离还有其它的很多经典方法,通常它们需要满足一些基本性质:

最常用的距离度量方法是**“闵可夫斯基距离”(Minkowski distance)**:

当p=1时,闵可夫斯基距离即曼哈顿距离(Manhattan distance):

当p=2时,闵可夫斯基距离即欧氏距离(Euclidean distance):

我们知道属性分为两种:连续属性和离散属性(有限个取值)。对于连续值的属性,一般都可以被学习器所用,有时会根据具体的情形作相应的预处理,例如:归一化等;而对于离散值的属性,需要作下面进一步的处理:
若属性值之间存在序关系,则可以将其转化为连续值,例如:身高属性“高”“中等”“矮”,可转化为{1, 0.5, 0}。 若属性值之间不存在序关系,则通常将其转化为向量的形式,例如:性别属性“男”“女”,可转化为{(1,0),(0,1)}。
在进行距离度量时,易知连续属性和存在序关系的离散属性都可以直接参与计算,因为它们都可以反映一种程度,我们称其为“有序属性”;而对于不存在序关系的离散属性,我们称其为:“无序属性”,显然无序属性再使用闵可夫斯基距离就行不通了。
对于无序属性,我们一般采用VDM进行距离的计算,例如:对于离散属性的两个取值a和b,定义:

于是,在计算两个样本之间的距离时,我们可以将闵可夫斯基距离和VDM混合在一起进行计算:

若我们定义的距离计算方法是用来度量相似性,例如下面将要讨论的聚类问题,即距离越小,相似性越大,反之距离越大,相似性越小。这时距离的度量方法并不一定需要满足前面所说的四个基本性质,这样的方法称为:非度量距离(non-metric distance)。
10.2 性能度量
由于聚类算法不依赖于样本的真实类标,就不能像监督学习的分类那般,通过计算分对分错(即精确度或错误率)来评价学习器的好坏或作为学习过程中的优化目标。一般聚类有两类性能度量指标:外部指标和内部指标。
10.2.1 外部指标
即将聚类结果与某个参考模型的结果进行比较,以参考模型的输出作为标准,来评价聚类好坏。假设聚类给出的结果为λ,参考模型给出的结果是λ*,则我们将样本进行两两配对,定义:

显然a和b代表着聚类结果好坏的正能量,b和c则表示参考结果和聚类结果相矛盾,基于这四个值可以导出以下常用的外部评价指标:

10.2.2 内部指标
内部指标即不依赖任何外部模型,直接对聚类的结果进行评估,聚类的目的是想将那些相似的样本尽可能聚在一起,不相似的样本尽可能分开,直观来说:簇内高内聚紧紧抱团,簇间低耦合老死不相往来。定义:

基于上面的四个距离,可以导出下面这些常用的内部评价指标:

10.3 原型聚类
原型聚类即“基于原型的聚类”(prototype-based clustering),原型表示模板的意思,就是通过参考一个模板向量或模板分布的方式来完成聚类的过程,常见的K-Means便是基于簇中心来实现聚类,混合高斯聚类则是基于簇分布来实现聚类。
10.3.1 K-Means
K-Means的思想十分简单,首先随机指定类中心,根据样本与类中心的远近划分类簇,接着重新计算类中心,迭代直至收敛。但是其中迭代的过程并不是主观地想象得出,事实上,若将样本的类别看做为“隐变量”(latent variable),类中心看作样本的分布参数,这一过程正是通过EM算法的两步走策略而计算出,其根本的目的是为了最小化平方误差函数E:

K-Means的算法流程如下所示:

10.3.2 学习向量量化(LVQ)
LVQ也是基于原型的聚类算法,与K-Means不同的是,LVQ使用样本真实类标记辅助聚类,首先LVQ根据样本的类标记,从各类中分别随机选出一个样本作为该类簇的原型,从而组成了一个原型特征向量组,接着从样本集中随机挑选一个样本,计算其与原型向量组中每个向量的距离,并选取距离最小的原型向量所在的类簇作为它的划分结果,再与真实类标比较。
若划分结果正确,则对应原型向量向这个样本靠近一些 若划分结果不正确,则对应原型向量向这个样本远离一些
LVQ算法的流程如下所示:

10.3.3 高斯混合聚类
现在可以看出K-Means与LVQ都试图以类中心作为原型指导聚类,高斯混合聚类则采用高斯分布来描述原型。现假设每个类簇中的样本都服从一个多维高斯分布,那么空间中的样本可以看作由k个多维高斯分布混合而成。
对于多维高斯分布,其概率密度函数如下所示:

其中u表示均值向量,∑表示协方差矩阵,可以看出一个多维高斯分布完全由这两个参数所确定。接着定义高斯混合分布为:

α称为混合系数,这样空间中样本的采集过程则可以抽象为:(1)先选择一个类簇(高斯分布),(2)再根据对应高斯分布的密度函数进行采样,这时候贝叶斯公式又能大展身手了:

此时只需要选择PM最大时的类簇并将该样本划分到其中,看到这里很容易发现:这和那个传说中的贝叶斯分类不是神似吗,都是通过贝叶斯公式展开,然后计算类先验概率和类条件概率。但遗憾的是:这里没有真实类标信息,对于类条件概率,并不能像贝叶斯分类那样通过最大似然法美好地计算出来,因为这里的样本可能属于所有的类簇,这里的似然函数变为:

可以看出:简单的最大似然法根本无法求出所有的参数,这样PM也就没法计算。这里就要召唤出之前的EM大法,首先对高斯分布的参数及混合系数进行随机初始化,计算出各个PM(即γji,第i个样本属于j类),再最大化似然函数(即LL(D)分别对α、u和∑求偏导 ),对参数进行迭代更新。

高斯混合聚类的算法流程如下图所示:

10.4 密度聚类
密度聚类则是基于密度的聚类,它从样本分布的角度来考察样本之间的可连接性,并基于可连接性(密度可达)不断拓展疆域(类簇)。其中最著名的便是DBSCAN算法,首先定义以下概念:


简单来理解DBSCAN便是:找出一个核心对象所有密度可达的样本集合形成簇。首先从数据集中任选一个核心对象A,找出所有A密度可达的样本集合,将这些样本形成一个密度相连的类簇,直到所有的核心对象都遍历完。DBSCAN算法的流程如下图所示:

10.5 层次聚类
层次聚类是一种基于树形结构的聚类方法,常用的是自底向上的结合策略(AGNES算法)。假设有N个待聚类的样本,其基本步骤是:
1.初始化-->把每个样本归为一类,计算每两个类之间的距离,也就是样本与样本之间的相似度; 2.寻找各个类之间最近的两个类,把他们归为一类(这样类的总数就少了一个); 3.重新计算新生成的这个类与各个旧类之间的相似度; 4.重复2和3直到所有样本点都归为一类,结束。
可以看出其中最关键的一步就是计算两个类簇的相似度,这里有多种度量方法:
-
单链接(single-linkage):取类间最小距离。

-
全链接(complete-linkage):取类间最大距离

-
均链接(average-linkage):取类间两两的平均距离

很容易看出:单链接的包容性极强,稍微有点暧昧就当做是自己人了,全链接则是坚持到底,只要存在缺点就坚决不合并,均连接则是从全局出发顾全大局。层次聚类法的算法流程如下所示:

在此聚类算法就介绍完毕,分类/聚类都是机器学习中最常见的任务,我实验室的大Boss也是靠着聚类起家,从此走上人生事业钱途...之巅峰,在书最后的阅读材料还看见Boss的名字,所以这章也是必读不可了...
十一. 降维与度量学习
样本的特征数称为维数(dimensionality),当维数非常大时,也就是现在所说的“维数灾难”,具体表现在:在高维情形下,数据样本将变得十分稀疏,因为此时要满足训练样本为“密采样”的总体样本数目是一个触不可及的天文数字,谓可远观而不可亵玩焉...训练样本的稀疏使得其代表总体分布的能力大大减弱,从而消减了学习器的泛化能力;同时当维数很高时,计算距离也变得十分复杂,甚至连计算内积都不再容易,这也是为什么支持向量机(SVM)使用核函数**“低维计算,高维表现”**的原因。
缓解维数灾难的一个重要途径就是降维,即通过某种数学变换将原始高维空间转变到一个低维的子空间。在这个子空间中,样本的密度将大幅提高,同时距离计算也变得容易。这时也许会有疑问,这样降维之后不是会丢失原始数据的一部分信息吗?这是因为在很多实际的问题中,虽然训练数据是高维的,但是与学习任务相关也许仅仅是其中的一个低维子空间,也称为一个低维嵌入,例如:数据属性中存在噪声属性、相似属性或冗余属性等,对高维数据进行降维能在一定程度上达到提炼低维优质属性或降噪的效果。
11.1 K近邻学习
k近邻算法简称kNN(k-Nearest Neighbor),是一种经典的监督学习方法,同时也实力担当入选数据挖掘十大算法。其工作机制十分简单粗暴:给定某个测试样本,kNN基于某种距离度量在训练集中找出与其距离最近的k个带有真实标记的训练样本,然后给基于这k个邻居的真实标记来进行预测,类似于前面集成学习中所讲到的基学习器结合策略:分类任务采用投票法,回归任务则采用平均法。接下来本篇主要就kNN分类进行讨论。

从上图【来自Wiki】中我们可以看到,图中有两种类型的样本,一类是蓝色正方形,另一类是红色三角形。而那个绿色圆形是我们待分类的样本。基于kNN算法的思路,我们很容易得到以下结论:
如果K=3,那么离绿色点最近的有2个红色三角形和1个蓝色的正方形,这3个点投票,于是绿色的这个待分类点属于红色的三角形。 如果K=5,那么离绿色点最近的有2个红色三角形和3个蓝色的正方形,这5个点投票,于是绿色的这个待分类点属于蓝色的正方形。
可以发现:kNN虽然是一种监督学习方法,但是它却没有显式的训练过程,而是当有新样本需要预测时,才来计算出最近的k个邻居,因此kNN是一种典型的懒惰学习方法,再来回想一下朴素贝叶斯的流程,训练的过程就是参数估计,因此朴素贝叶斯也可以懒惰式学习,此类技术在训练阶段开销为零,待收到测试样本后再进行计算。相应地我们称那些一有训练数据立马开工的算法为“急切学习”,可见前面我们学习的大部分算法都归属于急切学习。
很容易看出:kNN算法的核心在于k值的选取以及距离的度量。k值选取太小,模型很容易受到噪声数据的干扰,例如:极端地取k=1,若待分类样本正好与一个噪声数据距离最近,就导致了分类错误;若k值太大, 则在更大的邻域内进行投票,此时模型的预测能力大大减弱,例如:极端取k=训练样本数,就相当于模型根本没有学习,所有测试样本的预测结果都是一样的。一般地我们都通过交叉验证法来选取一个适当的k值。

对于距离度量,不同的度量方法得到的k个近邻不尽相同,从而对最终的投票结果产生了影响,因此选择一个合适的距离度量方法也十分重要。在上一篇聚类算法中,在度量样本相似性时介绍了常用的几种距离计算方法,包括闵可夫斯基距离,曼哈顿距离,VDM等。在实际应用中,kNN的距离度量函数一般根据样本的特性来选择合适的距离度量,同时应对数据进行去量纲/归一化处理来消除大量纲属性的强权政治影响。
11.2 MDS算法
不管是使用核函数升维还是对数据降维,我们都希望原始空间样本点之间的距离在新空间中基本保持不变,这样才不会使得原始空间样本之间的关系及总体分布发生较大的改变。**“多维缩放”(MDS)**正是基于这样的思想,MDS要求原始空间样本之间的距离在降维后的低维空间中得以保持。
假定m个样本在原始空间中任意两两样本之间的距离矩阵为D∈R(m*m),我们的目标便是获得样本在低维空间中的表示Z∈R(d'*m , d'< d),且任意两个样本在低维空间中的欧式距离等于原始空间中的距离,即||zi-zj||=Dist(ij)。因此接下来我们要做的就是根据已有的距离矩阵D来求解出降维后的坐标矩阵Z。

令降维后的样本坐标矩阵Z被中心化,中心化是指将每个样本向量减去整个样本集的均值向量,故所有样本向量求和得到一个零向量。这样易知:矩阵B的每一列以及每一列求和均为0,因为提取公因子后都有一项为所有样本向量的和向量。

根据上面矩阵B的特征,我们很容易得到等式(2)、(3)以及(4):

这时根据(1)--(4)式我们便可以计算出bij,即bij=(1)-(2)(1/m)-(3)(1/m)+(4)*(1/(m^2)),再逐一地计算每个b(ij),就得到了降维后低维空间中的内积矩阵B(B=Z'*Z),只需对B进行特征值分解便可以得到Z。MDS的算法流程如下图所示:

11.3 主成分分析(PCA)
不同于MDS采用距离保持的方法,主成分分析(PCA)直接通过一个线性变换,将原始空间中的样本投影到新的低维空间中。简单来理解这一过程便是:PCA采用一组新的基来表示样本点,其中每一个基向量都是原来基向量的线性组合,通过使用尽可能少的新基向量来表出样本,从而达到降维的目的。
假设使用d'个新基向量来表示原来样本,实质上是将样本投影到一个由d'个基向量确定的一个超平面上(即舍弃了一些维度),要用一个超平面对空间中所有高维样本进行恰当的表达,最理想的情形是:若这些样本点都能在超平面上表出且这些表出在超平面上都能够很好地分散开来。但是一般使用较原空间低一些维度的超平面来做到这两点十分不容易,因此我们退一步海阔天空,要求这个超平面应具有如下两个性质:
最近重构性:样本点到超平面的距离足够近,即尽可能在超平面附近; 最大可分性:样本点在超平面上的投影尽可能地分散开来,即投影后的坐标具有区分性。
这里十分神奇的是:最近重构性与最大可分性虽然从不同的出发点来定义优化问题中的目标函数,但最终这两种特性得到了完全相同的优化问题:

接着使用拉格朗日乘子法求解上面的优化问题,得到:

因此只需对协方差矩阵进行特征值分解即可求解出W,PCA算法的整个流程如下图所示:

另一篇博客给出更通俗更详细的理解:主成分分析解析(基于最大方差理论)
11.4 核化线性降维
说起机器学习你中有我/我中有你/水乳相融...在这里能够得到很好的体现。正如SVM在处理非线性可分时,通过引入核函数将样本投影到高维特征空间,接着在高维空间再对样本点使用超平面划分。这里也是相同的问题:若我们的样本数据点本身就不是线性分布,那还如何使用一个超平面去近似表出呢?因此也就引入了核函数,即先将样本映射到高维空间,再在高维空间中使用线性降维的方法。下面主要介绍**核化主成分分析(KPCA)**的思想。
若核函数的形式已知,即我们知道如何将低维的坐标变换为高维坐标,这时我们只需先将数据映射到高维特征空间,再在高维空间中运用PCA即可。但是一般情况下,我们并不知道核函数具体的映射规则,例如:Sigmoid、高斯核等,我们只知道如何计算高维空间中的样本内积,这时就引出了KPCA的一个重要创新之处:即空间中的任一向量,都可以由该空间中的所有样本线性表示。证明过程也十分简单:

这样我们便可以将高维特征空间中的投影向量wi使用所有高维样本点线性表出,接着代入PCA的求解问题,得到:

化简到最后一步,发现结果十分的美妙,只需对核矩阵K进行特征分解,便可以得出投影向量wi对应的系数向量α,因此选取特征值前d'大对应的特征向量便是d'个系数向量。这时对于需要降维的样本点,只需按照以下步骤便可以求出其降维后的坐标。可以看出:KPCA在计算降维后的坐标表示时,需要与所有样本点计算核函数值并求和,因此该算法的计算开销十分大。

11.5 流形学习
流形学习(manifold learning)是一种借助拓扑流形概念的降维方法,流形是指在局部与欧式空间同胚的空间,即在局部与欧式空间具有相同的性质,能用欧氏距离计算样本之间的距离。这样即使高维空间的分布十分复杂,但是在局部上依然满足欧式空间的性质,基于流形学习的降维正是这种**“邻域保持”的思想。其中等度量映射(Isomap)试图在降维前后保持邻域内样本之间的距离,而局部线性嵌入(LLE)则是保持邻域内样本之间的线性关系**,下面将分别对这两种著名的流行学习方法进行介绍。
11.5.1 等度量映射(Isomap)
等度量映射的基本出发点是:高维空间中的直线距离具有误导性,因为有时高维空间中的直线距离在低维空间中是不可达的。因此利用流形在局部上与欧式空间同胚的性质,可以使用近邻距离来逼近测地线距离,即对于一个样本点,它与近邻内的样本点之间是可达的,且距离使用欧式距离计算,这样整个样本空间就形成了一张近邻图,高维空间中两个样本之间的距离就转为最短路径问题。可采用著名的Dijkstra算法或Floyd算法计算最短距离,得到高维空间中任意两点之间的距离后便可以使用MDS算法来其计算低维空间中的坐标。

从MDS算法的描述中我们可以知道:MDS先求出了低维空间的内积矩阵B,接着使用特征值分解计算出了样本在低维空间中的坐标,但是并没有给出通用的投影向量w,因此对于需要降维的新样本无从下手,书中给出的权宜之计是利用已知高/低维坐标的样本作为训练集学习出一个“投影器”,便可以用高维坐标预测出低维坐标。Isomap算法流程如下图:

对于近邻图的构建,常用的有两种方法:一种是指定近邻点个数,像kNN一样选取k个最近的邻居;另一种是指定邻域半径,距离小于该阈值的被认为是它的近邻点。但两种方法均会出现下面的问题:
若邻域范围指定过大,则会造成“短路问题”,即本身距离很远却成了近邻,将距离近的那些样本扼杀在摇篮。 若邻域范围指定过小,则会造成“断路问题”,即有些样本点无法可达了,整个世界村被划分为互不可达的小部落。
11.5.2 局部线性嵌入(LLE)
不同于Isomap算法去保持邻域距离,LLE算法试图去保持邻域内的线性关系,假定样本xi的坐标可以通过它的邻域样本线性表出:


LLE算法分为两步走,首先第一步根据近邻关系计算出所有样本的邻域重构系数w:

接着根据邻域重构系数不变,去求解低维坐标:

这样利用矩阵M,优化问题可以重写为:

M特征值分解后最小的d'个特征值对应的特征向量组成Z,LLE算法的具体流程如下图所示:

11.6 度量学习
本篇一开始就提到维数灾难,即在高维空间进行机器学习任务遇到样本稀疏、距离难计算等诸多的问题,因此前面讨论的降维方法都试图将原空间投影到一个合适的低维空间中,接着在低维空间进行学习任务从而产生较好的性能。事实上,不管高维空间还是低维空间都潜在对应着一个距离度量,那可不可以直接学习出一个距离度量来等效降维呢?例如:咋们就按照降维后的方式来进行距离的计算,这便是度量学习的初衷。
首先要学习出距离度量必须先定义一个合适的距离度量形式。对两个样本xi与xj,它们之间的平方欧式距离为:

若各个属性重要程度不一样即都有一个权重,则得到加权的平方欧式距离:

此时各个属性之间都是相互独立无关的,但现实中往往会存在属性之间有关联的情形,例如:身高和体重,一般人越高,体重也会重一些,他们之间存在较大的相关性。这样计算距离就不能分属性单独计算,于是就引入经典的马氏距离(Mahalanobis distance):

标准的马氏距离中M是协方差矩阵的逆,马氏距离是一种考虑属性之间相关性且尺度无关(即无须去量纲)的距离度量。

矩阵M也称为“度量矩阵”,为保证距离度量的非负性与对称性,M必须为(半)正定对称矩阵,这样就为度量学习定义好了距离度量的形式,换句话说:度量学习便是对度量矩阵进行学习。现在来回想一下前面我们接触的机器学习不难发现:机器学习算法几乎都是在优化目标函数,从而求解目标函数中的参数。同样对于度量学习,也需要设置一个优化目标,书中简要介绍了错误率和相似性两种优化目标,此处限于篇幅不进行展开。
在此,降维和度量学习就介绍完毕。降维是将原高维空间嵌入到一个合适的低维子空间中,接着在低维空间中进行学习任务;度量学习则是试图去学习出一个距离度量来等效降维的效果,两者都是为了解决维数灾难带来的诸多问题。也许大家最后心存疑惑,那kNN呢,为什么一开头就说了kNN算法,但是好像和后面没有半毛钱关系?正是因为在降维算法中,低维子空间的维数d'通常都由人为指定,因此我们需要使用一些低开销的学习器来选取合适的d',kNN这家伙懒到家了根本无心学习,在训练阶段开销为零,测试阶段也只是遍历计算了距离,因此拿kNN来进行交叉验证就十分有优势了~同时降维后样本密度增大同时距离计算变易,更为kNN来展示它独特的十八般手艺提供了用武之地。
十二. 特征选择与稀疏学习
最近在看论文的过程中,发现对于数据集行和列的叫法颇有不同,故在介绍本篇之前,决定先将最常用的术语罗列一二,以后再见到了不管它脚扑朔还是眼迷离就能一眼识破真身了~对于数据集中的一个对象及组成对象的零件元素:
统计学家常称它们为观测(observation)和变量(variable); 数据库分析师则称其为记录(record)和字段(field); 数据挖掘/机器学习学科的研究者则习惯把它们叫做样本/示例(example/instance)和属性/特征(attribute/feature)。
回归正题,在机器学习中特征选择是一个重要的“数据预处理”(data preprocessing)过程,即试图从数据集的所有特征中挑选出与当前学习任务相关的特征子集,接着再利用数据子集来训练学习器;稀疏学习则是围绕着稀疏矩阵的优良性质,来完成相应的学习任务。
12.1 子集搜索与评价
一般地,我们可以用很多属性/特征来描述一个示例,例如对于一个人可以用性别、身高、体重、年龄、学历、专业、是否吃货等属性来描述,那现在想要训练出一个学习器来预测人的收入。根据生活经验易知:并不是所有的特征都与学习任务相关,例如年龄/学历/专业可能很大程度上影响了收入,身高/体重这些外貌属性也有较小的可能性影响收入,但像是否是一个地地道道的吃货这种属性就八杆子打不着了。因此我们只需要那些与学习任务紧密相关的特征,特征选择便是从给定的特征集合中选出相关特征子集的过程。
与上篇中降维技术有着异曲同工之处的是,特征选择也可以有效地解决维数灾难的难题。具体而言:降维从一定程度起到了提炼优质低维属性和降噪的效果,特征选择则是直接剔除那些与学习任务无关的属性而选择出最佳特征子集。若直接遍历所有特征子集,显然当维数过多时遭遇指数爆炸就行不通了;若采取从候选特征子集中不断迭代生成更优候选子集的方法,则时间复杂度大大减小。这时就涉及到了两个关键环节:1.如何生成候选子集;2.如何评价候选子集的好坏,这便是早期特征选择的常用方法。书本上介绍了贪心算法,分为三种策略:
前向搜索:初始将每个特征当做一个候选特征子集,然后从当前所有的候选子集中选择出最佳的特征子集;接着在上一轮选出的特征子集中添加一个新的特征,同样地选出最佳特征子集;最后直至选不出比上一轮更好的特征子集。 后向搜索:初始将所有特征作为一个候选特征子集;接着尝试去掉上一轮特征子集中的一个特征并选出当前最优的特征子集;最后直到选不出比上一轮更好的特征子集。 双向搜索:将前向搜索与后向搜索结合起来,即在每一轮中既有添加操作也有剔除操作。
对于特征子集的评价,书中给出了一些想法及基于信息熵的方法。假设数据集的属性皆为离散属性,这样给定一个特征子集,便可以通过这个特征子集的取值将数据集合划分为V个子集。例如:A1={男,女},A2={本科,硕士}就可以将原数据集划分为2*2=4个子集,其中每个子集的取值完全相同。这时我们就可以像决策树选择划分属性那样,通过计算信息增益来评价该属性子集的好坏。

此时,信息增益越大表示该属性子集包含有助于分类的特征越多,使用上述这种子集搜索与子集评价相结合的机制,便可以得到特征选择方法。值得一提的是若将前向搜索策略与信息增益结合在一起,与前面我们讲到的ID3决策树十分地相似。事实上,决策树也可以用于特征选择,树节点划分属性组成的集合便是选择出的特征子集。
12.2 过滤式选择(Relief)
过滤式方法是一种将特征选择与学习器训练相分离的特征选择技术,即首先将相关特征挑选出来,再使用选择出的数据子集来训练学习器。Relief是其中著名的代表性算法,它使用一个“相关统计量”来度量特征的重要性,该统计量是一个向量,其中每个分量代表着相应特征的重要性,因此我们最终可以根据这个统计量各个分量的大小来选择出合适的特征子集。
易知Relief算法的核心在于如何计算出该相关统计量。对于数据集中的每个样例xi,Relief首先找出与xi同类别的最近邻与不同类别的最近邻,分别称为猜中近邻(near-hit)与猜错近邻(near-miss),接着便可以分别计算出相关统计量中的每个分量。对于j分量:

直观上理解:对于猜中近邻,两者j属性的距离越小越好,对于猜错近邻,j属性距离越大越好。更一般地,若xi为离散属性,diff取海明距离,即相同取0,不同取1;若xi为连续属性,则diff为曼哈顿距离,即取差的绝对值。分别计算每个分量,最终取平均便得到了整个相关统计量。
标准的Relief算法只用于二分类问题,后续产生的拓展变体Relief-F则解决了多分类问题。对于j分量,新的计算公式如下:

其中pl表示第l类样本在数据集中所占的比例,易知两者的不同之处在于:标准Relief 只有一个猜错近邻,而Relief-F有多个猜错近邻。
12.3 包裹式选择(LVW)
与过滤式选择不同的是,包裹式选择将后续的学习器也考虑进来作为特征选择的评价准则。因此包裹式选择可以看作是为某种学习器量身定做的特征选择方法,由于在每一轮迭代中,包裹式选择都需要训练学习器,因此在获得较好性能的同时也产生了较大的开销。下面主要介绍一种经典的包裹式特征选择方法 --LVW(Las Vegas Wrapper),它在拉斯维加斯框架下使用随机策略来进行特征子集的搜索。拉斯维加斯?怎么听起来那么耳熟,不是那个声名显赫的赌场吗?歪果仁真会玩。怀着好奇科普一下,结果又顺带了一个赌场:
蒙特卡罗算法:采样越多,越近似最优解,一定会给出解,但给出的解不一定是正确解; 拉斯维加斯算法:采样越多,越有机会找到最优解,不一定会给出解,且给出的解一定是正确解。
举个例子,假如筐里有100个苹果,让我每次闭眼拿1个,挑出最大的。于是我随机拿1个,再随机拿1个跟它比,留下大的,再随机拿1个……我每拿一次,留下的苹果都至少不比上次的小。拿的次数越多,挑出的苹果就越大,但我除非拿100次,否则无法肯定挑出了最大的。这个挑苹果的算法,就属于蒙特卡罗算法——尽量找较好的,但不保证是最好的。
而拉斯维加斯算法,则是另一种情况。假如有一把锁,给我100把钥匙,只有1把是对的。于是我每次随机拿1把钥匙去试,打不开就再换1把。我试的次数越多,打开(正确解)的机会就越大,但在打开之前,那些错的钥匙都是没有用的。这个试钥匙的算法,就是拉斯维加斯的——尽量找最好的,但不保证能找到。
LVW算法的具体流程如下所示,其中比较特别的是停止条件参数T的设置,即在每一轮寻找最优特征子集的过程中,若随机T次仍没找到,算法就会停止,从而保证了算法运行时间的可行性。

12.4 嵌入式选择与正则化
前面提到了的两种特征选择方法:过滤式中特征选择与后续学习器完全分离,包裹式则是使用学习器作为特征选择的评价准则;嵌入式是一种将特征选择与学习器训练完全融合的特征选择方法,即将特征选择融入学习器的优化过程中。在之前《经验风险与结构风险》中已经提到:经验风险指的是模型与训练数据的契合度,结构风险则是模型的复杂程度,机器学习的核心任务就是:在模型简单的基础上保证模型的契合度。例如:岭回归就是加上了L2范数的最小二乘法,有效地解决了奇异矩阵、过拟合等诸多问题,下面的嵌入式特征选择则是在损失函数后加上了L1范数。

L1范数美名又约Lasso Regularization,指的是向量中每个元素的绝对值之和,这样在优化目标函数的过程中,就会使得w尽可能地小,在一定程度上起到了防止过拟合的作用,同时与L2范数(Ridge Regularization )不同的是,L1范数会使得部分w变为0, 从而达到了特征选择的效果。
总的来说:L1范数会趋向产生少量的特征,其他特征的权值都是0;L2会选择更多的特征,这些特征的权值都会接近于0。这样L1范数在特征选择上就十分有用,而L2范数则具备较强的控制过拟合能力。可以从下面两个方面来理解:
(1)下降速度:L1范数按照绝对值函数来下降,L2范数按照二次函数来下降。因此在0附近,L1范数的下降速度大于L2范数,故L1范数能很快地下降到0,而L2范数在0附近的下降速度非常慢,因此较大可能收敛在0的附近。

(2)空间限制:L1范数与L2范数都试图在最小化损失函数的同时,让权值W也尽可能地小。我们可以将原优化问题看做为下面的问题,即让后面的规则则都小于某个阈值。这样从图中可以看出:L1范数相比L2范数更容易得到稀疏解。


12.5 稀疏表示与字典学习
当样本数据是一个稀疏矩阵时,对学习任务来说会有不少的好处,例如很多问题变得线性可分,储存更为高效等。这便是稀疏表示与字典学习的基本出发点。稀疏矩阵即矩阵的每一行/列中都包含了大量的零元素,且这些零元素没有出现在同一行/列,对于一个给定的稠密矩阵,若我们能通过某种方法找到其合适的稀疏表示,则可以使得学习任务更加简单高效,我们称之为稀疏编码(sparse coding)或字典学习(dictionary learning)。
给定一个数据集,字典学习/稀疏编码指的便是通过一个字典将原数据转化为稀疏表示,因此最终的目标就是求得字典矩阵B及稀疏表示α,书中使用变量交替优化的策略能较好地求得解,深感陷进去短时间无法自拔,故先不进行深入...

12.6 压缩感知
压缩感知在前些年也是风风火火,与特征选择、稀疏表示不同的是:它关注的是通过欠采样信息来恢复全部信息。在实际问题中,为了方便传输和存储,我们一般将数字信息进行压缩,这样就有可能损失部分信息,如何根据已有的信息来重构出全部信号,这便是压缩感知的来历,压缩感知的前提是已知的信息具有稀疏表示。下面是关于压缩感知的一些背景:

十三. 计算学习理论
计算学习理论(computational learning theory)是通过“计算”来研究机器学习的理论,简而言之,其目的是分析学习任务的本质,例如:在什么条件下可进行有效的学习,需要多少训练样本能获得较好的精度等,从而为机器学习算法提供理论保证。
首先我们回归初心,再来谈谈经验误差和泛化误差。假设给定训练集D,其中所有的训练样本都服从一个未知的分布T,且它们都是在总体分布T中独立采样得到,即独立同分布(independent and identically distributed,i.i.d.),在《贝叶斯分类器》中我们已经提到:独立同分布是很多统计学习算法的基础假设,例如最大似然法,贝叶斯分类器,高斯混合聚类等,简单来理解独立同分布:每个样本都是从总体分布中独立采样得到,而没有拖泥带水。例如现在要进行问卷调查,要从总体人群中随机采样,看到一个美女你高兴地走过去,结果她男票突然冒了出来,说道:you jump,i jump,于是你本来只想调查一个人结果被强行撒了一把狗粮得到两份问卷,这样这两份问卷就不能称为独立同分布了,因为它们的出现具有强相关性。
回归正题,泛化误差指的是学习器在总体上的预测误差,经验误差则是学习器在某个特定数据集D上的预测误差。在实际问题中,往往我们并不能得到总体且数据集D是通过独立同分布采样得到的,因此我们常常使用经验误差作为泛化误差的近似。

13.1 PAC学习
在高中课本中,我们将函数定义为:从自变量到因变量的一种映射;对于机器学习算法,学习器也正是为了寻找合适的映射规则,即如何从条件属性得到目标属性。从样本空间到标记空间存在着很多的映射,我们将每个映射称之为概念(concept),定义:
若概念c对任何样本x满足c(x)=y,则称c为目标概念,即最理想的映射,所有的目标概念构成的集合称为**“概念类”; 给定学习算法,它所有可能映射/概念的集合称为“假设空间”,其中单个的概念称为“假设”(hypothesis); 若一个算法的假设空间包含目标概念,则称该数据集对该算法是可分**(separable)的,亦称一致(consistent)的; 若一个算法的假设空间不包含目标概念,则称该数据集对该算法是不可分(non-separable)的,或称不一致(non-consistent)的。
举个简单的例子:对于非线性分布的数据集,若使用一个线性分类器,则该线性分类器对应的假设空间就是空间中所有可能的超平面,显然假设空间不包含该数据集的目标概念,所以称数据集对该学习器是不可分的。给定一个数据集D,我们希望模型学得的假设h尽可能地与目标概念一致,这便是概率近似正确 (Probably Approximately Correct,简称PAC)的来源,即以较大的概率学得模型满足误差的预设上限。




上述关于PAC的几个定义层层相扣:定义12.1表达的是对于某种学习算法,如果能以一个置信度学得假设满足泛化误差的预设上限,则称该算法能PAC辨识概念类,即该算法的输出假设已经十分地逼近目标概念。定义12.2则将样本数量考虑进来,当样本超过一定数量时,学习算法总是能PAC辨识概念类,则称概念类为PAC可学习的。定义12.3将学习器运行时间也考虑进来,若运行时间为多项式时间,则称PAC学习算法。
显然,PAC学习中的一个关键因素就是假设空间的复杂度,对于某个学习算法,若假设空间越大,则其中包含目标概念的可能性也越大,但同时找到某个具体概念的难度也越大,一般假设空间分为有限假设空间与无限假设空间。
13.2 有限假设空间
13.2.1 可分情形
可分或一致的情形指的是:目标概念包含在算法的假设空间中。对于目标概念,在训练集D中的经验误差一定为0,因此首先我们可以想到的是:不断地剔除那些出现预测错误的假设,直到找到经验误差为0的假设即为目标概念。但由于样本集有限,可能会出现多个假设在D上的经验误差都为0,因此问题转化为:需要多大规模的数据集D才能让学习算法以置信度的概率从这些经验误差都为0的假设中找到目标概念的有效近似。

通过上式可以得知:对于可分情形的有限假设空间,目标概念都是PAC可学习的,即当样本数量满足上述条件之后,在与训练集一致的假设中总是可以在1-σ概率下找到目标概念的有效近似。
13.2.2 不可分情形
不可分或不一致的情形指的是:目标概念不存在于假设空间中,这时我们就不能像可分情形时那样从假设空间中寻找目标概念的近似。但当假设空间给定时,必然存一个假设的泛化误差最小,若能找出此假设的有效近似也不失为一个好的目标,这便是不可知学习(agnostic learning)的来源。

这时候便要用到Hoeffding不等式:

对于假设空间中的所有假设,出现泛化误差与经验误差之差大于e的概率和为:

因此,可令不等式的右边小于(等于)σ,便可以求出满足泛化误差与经验误差相差小于e所需的最少样本数,同时也可以求出泛化误差界。

13.3 VC维
现实中的学习任务通常都是无限假设空间,例如d维实数域空间中所有的超平面等,因此要对此种情形进行可学习研究,需要度量假设空间的复杂度。这便是VC维(Vapnik-Chervonenkis dimension)的来源。在介绍VC维之前,需要引入两个概念:
增长函数:对于给定数据集D,假设空间中的每个假设都能对数据集的样本赋予标记,因此一个假设对应着一种打标结果,不同假设对D的打标结果可能是相同的,也可能是不同的。随着样本数量m的增大,假设空间对样本集D的打标结果也会增多,增长函数则表示假设空间对m个样本的数据集D打标的最大可能结果数,因此增长函数描述了假设空间的表示能力与复杂度。
打散:例如对二分类问题来说,m个样本最多有2^m个可能结果,每种可能结果称为一种**“对分”**,若假设空间能实现数据集D的所有对分,则称数据集能被该假设空间打散。
因此尽管假设空间是无限的,但它对特定数据集打标的不同结果数是有限的,假设空间的VC维正是它能打散的最大数据集大小。通常这样来计算假设空间的VC维:若存在大小为d的数据集能被假设空间打散,但不存在任何大小为d+1的数据集能被假设空间打散,则其VC维为d。

同时书中给出了假设空间VC维与增长函数的两个关系:

直观来理解(1)式也十分容易: 首先假设空间的VC维是d,说明当m<=d时,增长函数与2^m相等,例如:当m=d时,右边的组合数求和刚好等于2^d;而当m=d+1时,右边等于2^(d+1)-1,十分符合VC维的定义,同时也可以使用数学归纳法证明;(2)式则是由(1)式直接推导得出。
在有限假设空间中,根据Hoeffding不等式便可以推导得出学习算法的泛化误差界;但在无限假设空间中,由于假设空间的大小无法计算,只能通过增长函数来描述其复杂度,因此无限假设空间中的泛化误差界需要引入增长函数。


上式给出了基于VC维的泛化误差界,同时也可以计算出满足条件需要的样本数(样本复杂度)。若学习算法满足经验风险最小化原则(ERM),即学习算法的输出假设h在数据集D上的经验误差最小,可证明:任何VC维有限的假设空间都是(不可知)PAC可学习的,换而言之:若假设空间的最小泛化误差为0即目标概念包含在假设空间中,则是PAC可学习,若最小泛化误差不为0,则称为不可知PAC可学习。
13.4 稳定性
稳定性考察的是当算法的输入发生变化时,输出是否会随之发生较大的变化,输入的数据集D有以下两种变化:

若对数据集中的任何样本z,满足:

即原学习器和剔除一个样本后生成的学习器对z的损失之差保持β稳定,称学习器关于损失函数满足β-均匀稳定性。同时若损失函数有上界,即原学习器对任何样本的损失函数不超过M,则有如下定理:

事实上,若学习算法符合经验风险最小化原则(ERM)且满足β-均匀稳定性,则假设空间是可学习的。稳定性通过损失函数与假设空间的可学习联系在了一起,区别在于:假设空间关注的是经验误差与泛化误差,需要考虑到所有可能的假设;而稳定性只关注当前的输出假设。
十四. 半监督学习
前面我们一直围绕的都是监督学习与无监督学习,监督学习指的是训练样本包含标记信息的学习任务,例如:常见的分类与回归算法;无监督学习则是训练样本不包含标记信息的学习任务,例如:聚类算法。在实际生活中,常常会出现一部分样本有标记和较多样本无标记的情形,例如:做网页推荐时需要让用户标记出感兴趣的网页,但是少有用户愿意花时间来提供标记。若直接丢弃掉无标记样本集,使用传统的监督学习方法,常常会由于训练样本的不充足,使得其刻画总体分布的能力减弱,从而影响了学习器泛化性能。那如何利用未标记的样本数据呢?
一种简单的做法是通过专家知识对这些未标记的样本进行打标,但随之而来的就是巨大的人力耗费。若我们先使用有标记的样本数据集训练出一个学习器,再基于该学习器对未标记的样本进行预测,从中挑选出不确定性高或分类置信度低的样本来咨询专家并进行打标,最后使用扩充后的训练集重新训练学习器,这样便能大幅度降低标记成本,这便是主动学习(active learning),其目标是使用尽量少的/有价值的咨询来获得更好的性能。
显然,主动学习需要与外界进行交互/查询/打标,其本质上仍然属于一种监督学习。事实上,无标记样本虽未包含标记信息,但它们与有标记样本一样都是从总体中独立同分布采样得到,因此它们所包含的数据分布信息对学习器的训练大有裨益。如何让学习过程不依赖外界的咨询交互,自动利用未标记样本所包含的分布信息的方法便是半监督学习(semi-supervised learning),即训练集同时包含有标记样本数据和未标记样本数据。

此外,半监督学习还可以进一步划分为纯半监督学习和直推学习,两者的区别在于:前者假定训练数据集中的未标记数据并非待预测数据,而后者假定学习过程中的未标记数据就是待预测数据。主动学习、纯半监督学习以及直推学习三者的概念如下图所示:

14.1 生成式方法
生成式方法(generative methods)是基于生成式模型的方法,即先对联合分布P(x,c)建模,从而进一步求解 P(c | x),此类方法假定样本数据服从一个潜在的分布,因此需要充分可靠的先验知识。例如:前面已经接触到的贝叶斯分类器与高斯混合聚类,都属于生成式模型。现假定总体是一个高斯混合分布,即由多个高斯分布组合形成,从而一个子高斯分布就代表一个类簇(类别)。高斯混合分布的概率密度函数如下所示:

不失一般性,假设类簇与真实的类别按照顺序一一对应,即第i个类簇对应第i个高斯混合成分。与高斯混合聚类类似地,这里的主要任务也是估计出各个高斯混合成分的参数以及混合系数,不同的是:对于有标记样本,不再是可能属于每一个类簇,而是只能属于真实类标对应的特定类簇。

直观上来看,基于半监督的高斯混合模型有机地整合了贝叶斯分类器与高斯混合聚类的核心思想,有效地利用了未标记样本数据隐含的分布信息,从而使得参数的估计更加准确。同样地,这里也要召唤出之前的EM大法进行求解,首先对各个高斯混合成分的参数及混合系数进行随机初始化,计算出各个PM(即γji,第i个样本属于j类,有标记样本则直接属于特定类),再最大化似然函数(即LL(D)分别对α、u和∑求偏导 ),对参数进行迭代更新。

当参数迭代更新收敛后,对于待预测样本x,便可以像贝叶斯分类器那样计算出样本属于每个类簇的后验概率,接着找出概率最大的即可:

可以看出:基于生成式模型的方法十分依赖于对潜在数据分布的假设,即假设的分布要能和真实分布相吻合,否则利用未标记的样本数据反倒会在错误的道路上渐行渐远,从而降低学习器的泛化性能。因此,此类方法要求极强的领域知识和掐指观天的本领。
14.2 半监督SVM
监督学习中的SVM试图找到一个划分超平面,使得两侧支持向量之间的间隔最大,即“最大划分间隔”思想。对于半监督学习,S3VM则考虑超平面需穿过数据低密度的区域。TSVM是半监督支持向量机中的最著名代表,其核心思想是:尝试为未标记样本找到合适的标记指派,使得超平面划分后的间隔最大化。TSVM采用局部搜索的策略来进行迭代求解,即首先使用有标记样本集训练出一个初始SVM,接着使用该学习器对未标记样本进行打标,这样所有样本都有了标记,并基于这些有标记的样本重新训练SVM,之后再寻找易出错样本不断调整。整个算法流程如下所示:


14.3 基于分歧的方法
基于分歧的方法通过多个学习器之间的**分歧(disagreement)/多样性(diversity)**来利用未标记样本数据,协同训练就是其中的一种经典方法。协同训练最初是针对于多视图(multi-view)数据而设计的,多视图数据指的是样本对象具有多个属性集,每个属性集则对应一个试图。例如:电影数据中就包含画面类属性和声音类属性,这样画面类属性的集合就对应着一个视图。首先引入两个关于视图的重要性质:
相容性:即使用单个视图数据训练出的学习器的输出空间是一致的。例如都是{好,坏}、{+1,-1}等。 互补性:即不同视图所提供的信息是互补/相辅相成的,实质上这里体现的就是集成学习的思想。
协同训练正是很好地利用了多视图数据的“相容互补性”,其基本的思想是:首先基于有标记样本数据在每个视图上都训练一个初始分类器,然后让每个分类器去挑选分类置信度最高的样本并赋予标记,并将带有伪标记的样本数据传给另一个分类器去学习,从而你依我侬/共同进步。

14.4 半监督聚类
前面提到的几种方法都是借助无标记样本数据来辅助监督学习的训练过程,从而使得学习更加充分/泛化性能得到提升;半监督聚类则是借助已有的监督信息来辅助聚类的过程。一般而言,监督信息大致有两种类型:
必连与勿连约束:必连指的是两个样本必须在同一个类簇,勿连则是必不在同一个类簇。 标记信息:少量的样本带有真实的标记。
下面主要介绍两种基于半监督的K-Means聚类算法:第一种是数据集包含一些必连与勿连关系,另外一种则是包含少量带有标记的样本。两种算法的基本思想都十分的简单:对于带有约束关系的k-均值算法,在迭代过程中对每个样本划分类簇时,需要检测当前划分是否满足约束关系,若不满足则会将该样本划分到距离次小对应的类簇中,再继续检测是否满足约束关系,直到完成所有样本的划分。算法流程如下图所示:

对于带有少量标记样本的k-均值算法,则可以利用这些有标记样本进行类中心的指定,同时在对样本进行划分时,不需要改变这些有标记样本的簇隶属关系,直接将其划分到对应类簇即可。算法流程如下所示:

十五. 概率图模型
现在再来谈谈机器学习的核心价值观,可以更通俗地理解为:根据一些已观察到的证据来推断未知,更具哲学性地可以阐述为:未来的发展总是遵循着历史的规律。其中基于概率的模型将学习任务归结为计算变量的概率分布,正如之前已经提到的:生成式模型先对联合分布进行建模,从而再来求解后验概率,例如:贝叶斯分类器先对联合分布进行最大似然估计,从而便可以计算类条件概率;判别式模型则是直接对条件分布进行建模。
概率图模型(probabilistic graphical model)是一类用图结构来表达各属性之间相关关系的概率模型,一般而言:图中的一个结点表示一个或一组随机变量,结点之间的边则表示变量间的相关关系,从而形成了一张“变量关系图”。若使用有向的边来表达变量之间的依赖关系,这样的有向关系图称为贝叶斯网(Bayesian nerwork)或有向图模型;若使用无向边,则称为马尔可夫网(Markov network)或无向图模型。
15.1 隐马尔可夫模型(HMM)
隐马尔可夫模型(Hidden Markov Model,简称HMM)是结构最简单的一种贝叶斯网,在语音识别与自然语言处理领域上有着广泛的应用。HMM中的变量分为两组:状态变量与观测变量,其中状态变量一般是未知的,因此又称为“隐变量”,观测变量则是已知的输出值。在隐马尔可夫模型中,变量之间的依赖关系遵循如下两个规则:
1. 观测变量的取值仅依赖于状态变量; 2. 下一个状态的取值仅依赖于当前状态,通俗来讲:现在决定未来,未来与过去无关,这就是著名的马尔可夫性。

基于上述变量之间的依赖关系,我们很容易写出隐马尔可夫模型中所有变量的联合概率分布:

易知:欲确定一个HMM模型需要以下三组参数:

当确定了一个HMM模型的三个参数后,便按照下面的规则来生成观测值序列:

在实际应用中,HMM模型的发力点主要体现在下述三个问题上:

15.1.1 HMM评估问题
HMM评估问题指的是:给定了模型的三个参数与观测值序列,求该观测值序列出现的概率。例如:对于赌场问题,便可以依据骰子掷出的结果序列来计算该结果序列出现的可能性,若小概率的事件发生了则可认为赌场的骰子有作弊的可能。解决该问题使用的是前向算法,即步步为营,自底向上的方式逐步增加序列的长度,直到获得目标概率值。在前向算法中,定义了一个前向变量,即给定观察值序列且t时刻的状态为Si的概率:

基于前向变量,很容易得到该问题的递推关系及终止条件:

因此可使用动态规划法,从最小的子问题开始,通过填表格的形式一步一步计算出目标结果。
15.1.2 HMM解码问题
HMM解码问题指的是:给定了模型的三个参数与观测值序列,求可能性最大的状态序列。例如:在语音识别问题中,人说话形成的数字信号对应着观测值序列,对应的具体文字则是状态序列,从数字信号转化为文字正是对应着根据观测值序列推断最有可能的状态值序列。解决该问题使用的是Viterbi算法,与前向算法十分类似地,Viterbi算法定义了一个Viterbi变量,也是采用动态规划的方法,自底向上逐步求解。

15.1.3 HMM学习问题
HMM学习问题指的是:给定观测值序列,如何调整模型的参数使得该序列出现的概率最大。这便转化成了机器学习问题,即从给定的观测值序列中学习出一个HMM模型,该问题正是EM算法的经典案例之一。其思想也十分简单:对于给定的观测值序列,如果我们能够按照该序列潜在的规律来调整模型的三个参数,则可以使得该序列出现的可能性最大。假设状态值序列也已知,则很容易计算出与该序列最契合的模型参数:

但一般状态值序列都是不可观测的,且即使给定观测值序列与模型参数,状态序列仍然遭遇组合爆炸。因此上面这种简单的统计方法就行不通了,若将状态值序列看作为隐变量,这时便可以考虑使用EM算法来对该问题进行求解:
【1】首先对HMM模型的三个参数进行随机初始化; 【2】根据模型的参数与观测值序列,计算t时刻状态为i且t+1时刻状态为j的概率以及t时刻状态为i的概率。

【3】接着便可以对模型的三个参数进行重新估计:

【4】重复步骤2-3,直至三个参数值收敛,便得到了最终的HMM模型。
15.2 马尔可夫随机场(MRF)
马尔可夫随机场(Markov Random Field)是一种典型的马尔可夫网,即使用无向边来表达变量间的依赖关系。在马尔可夫随机场中,对于关系图中的一个子集,若任意两结点间都有边连接,则称该子集为一个团;若再加一个结点便不能形成团,则称该子集为极大团。MRF使用势函数来定义多个变量的概率分布函数,其中每个(极大)团对应一个势函数,一般团中的变量关系也体现在它所对应的极大团中,因此常常基于极大团来定义变量的联合概率分布函数。具体而言,若所有变量构成的极大团的集合为C,则MRF的联合概率函数可以定义为:

对于条件独立性,马尔可夫随机场通过分离集来实现条件独立,若A结点集必须经过C结点集才能到达B结点集,则称C为分离集。书上给出了一个简单情形下的条件独立证明过程,十分贴切易懂,此处不再展开。基于分离集的概念,得到了MRF的三个性质:
全局马尔可夫性:给定两个变量子集的分离集,则这两个变量子集条件独立。 局部马尔可夫性:给定某变量的邻接变量,则该变量与其它变量条件独立。 成对马尔可夫性:给定所有其他变量,两个非邻接变量条件独立。

对于MRF中的势函数,势函数主要用于描述团中变量之间的相关关系,且要求为非负函数,直观来看:势函数需要在偏好的变量取值上函数值较大,例如:若x1与x2成正相关,则需要将这种关系反映在势函数的函数值中。一般我们常使用指数函数来定义势函数:

15.3 条件随机场(CRF)
前面所讲到的隐马尔可夫模型和马尔可夫随机场都属于生成式模型,即对联合概率进行建模,条件随机场则是对条件分布进行建模。CRF试图在给定观测值序列后,对状态序列的概率分布进行建模,即P(y | x)。直观上看:CRF与HMM的解码问题十分类似,都是在给定观测值序列后,研究状态序列可能的取值。CRF可以有多种结构,只需保证状态序列满足马尔可夫性即可,一般我们常使用的是链式条件随机场:

与马尔可夫随机场定义联合概率类似地,CRF也通过团以及势函数的概念来定义条件概率P(y | x)。在给定观测值序列的条件下,链式条件随机场主要包含两种团结构:单个状态团及相邻状态团,通过引入两类特征函数便可以定义出目标条件概率:

以词性标注为例,如何判断给出的一个标注序列靠谱不靠谱呢?转移特征函数主要判定两个相邻的标注是否合理,例如:动词+动词显然语法不通;状态特征函数则判定观测值与对应的标注是否合理,例如: ly结尾的词-->副词较合理。因此我们可以定义一个特征函数集合,用这个特征函数集合来为一个标注序列打分,并据此选出最靠谱的标注序列。也就是说,每一个特征函数(对应一种规则)都可以用来为一个标注序列评分,把集合中所有特征函数对同一个标注序列的评分综合起来,就是这个标注序列最终的评分值。可以看出:特征函数是一些经验的特性。
15.4 学习与推断
对于生成式模型,通常我们都是先对变量的联合概率分布进行建模,接着再求出目标变量的边际分布(marginal distribution),那如何从联合概率得到边际分布呢?这便是学习与推断。下面主要介绍两种精确推断的方法:变量消去与信念传播。
15.4.1 变量消去
变量消去利用条件独立性来消减计算目标概率值所需的计算量,它通过运用乘法与加法的分配率,将对变量的积的求和问题转化为对部分变量交替进行求积与求和的问题,从而将每次的运算控制在局部,达到简化运算的目的。

15.4.2 信念传播
若将变量求和操作看作是一种消息的传递过程,信念传播可以理解成:一个节点在接收到所有其它节点的消息后才向另一个节点发送消息,同时当前节点的边际概率正比于他所接收的消息的乘积:

因此只需要经过下面两个步骤,便可以完成所有的消息传递过程。利用动态规划法的思想记录传递过程中的所有消息,当计算某个结点的边际概率分布时,只需直接取出传到该结点的消息即可,从而避免了计算多个边际分布时的冗余计算问题。
1.指定一个根节点,从所有的叶节点开始向根节点传递消息,直到根节点收到所有邻接结点的消息**(从叶到根); 2.从根节点开始向叶节点传递消息,直到所有叶节点均收到消息(从根到叶)**。

15.5 LDA话题模型
话题模型主要用于处理文本类数据,其中隐狄利克雷分配模型(Latent Dirichlet Allocation,简称LDA)是话题模型的杰出代表。在话题模型中,有以下几个基本概念:词(word)、文档(document)、话题(topic)。
词:最基本的离散单元; 文档:由一组词组成,词在文档中不计顺序; 话题:由一组特定的词组成,这组词具有较强的相关关系。
在现实任务中,一般我们可以得出一个文档的词频分布,但不知道该文档对应着哪些话题,LDA话题模型正是为了解决这个问题。具体来说:LDA认为每篇文档包含多个话题,且其中每一个词都对应着一个话题。因此可以假设文档是通过如下方式生成:

这样一个文档中的所有词都可以认为是通过话题模型来生成的,当已知一个文档的词频分布后(即一个N维向量,N为词库大小),则可以认为:每一个词频元素都对应着一个话题,而话题对应的词频分布则影响着该词频元素的大小。因此很容易写出LDA模型对应的联合概率函数:

从上图可以看出,LDA的三个表示层被三种颜色表示出来:
corpus-level(红色): α和β表示语料级别的参数,也就是每个文档都一样,因此生成过程只采样一次。 document-level(橙色): θ是文档级别的变量,每个文档对应一个θ。 word-level(绿色): z和w都是单词级别变量,z由θ生成,w由z和β共同生成,一个单词w对应一个主题z。
通过上面对LDA生成模型的讨论,可以知道LDA模型主要是想从给定的输入语料中学习训练出两个控制参数α和β,当学习出了这两个控制参数就确定了模型,便可以用来生成文档。其中α和β分别对应以下各个信息:
α:分布p(θ)需要一个向量参数,即Dirichlet分布的参数,用于生成一个主题θ向量; β:各个主题对应的单词概率分布矩阵p(w|z)。
把w当做观察变量,θ和z当做隐藏变量,就可以通过EM算法学习出α和β,求解过程中遇到后验概率p(θ,z|w)无法直接求解,需要找一个似然函数下界来近似求解,原作者使用基于分解(factorization)假设的变分法(varialtional inference)进行计算,用到了EM算法。每次E-step输入α和β,计算似然函数,M-step最大化这个似然函数,算出α和β,不断迭代直到收敛。
十六. 强化学习
强化学习(Reinforcement Learning,简称RL)是机器学习的一个重要分支,前段时间人机大战的主角AlphaGo正是以强化学习为核心技术。在强化学习中,包含两种基本的元素:状态与动作,在某个状态下执行某种动作,这便是一种策略,学习器要做的就是通过不断地探索学习,从而获得一个好的策略。例如:在围棋中,一种落棋的局面就是一种状态,若能知道每种局面下的最优落子动作,那就攻无不克/百战不殆了~
若将状态看作为属性,动作看作为标记,易知:监督学习和强化学习都是在试图寻找一个映射,从已知属性/状态推断出标记/动作,这样强化学习中的策略相当于监督学习中的分类/回归器。但在实际问题中,强化学习并没有监督学习那样的标记信息,通常都是在尝试动作后才能获得结果,因此强化学习是通过反馈的结果信息不断调整之前的策略,从而算法能够学习到:在什么样的状态下选择什么样的动作可以获得最好的结果。
16.1 基本要素
强化学习任务通常使用马尔可夫决策过程(Markov Decision Process,简称MDP)来描述,具体而言:机器处在一个环境中,每个状态为机器对当前环境的感知;机器只能通过动作来影响环境,当机器执行一个动作后,会使得环境按某种概率转移到另一个状态;同时,环境会根据潜在的奖赏函数反馈给机器一个奖赏。综合而言,强化学习主要包含四个要素:状态、动作、转移概率以及奖赏函数。
状态(X):机器对环境的感知,所有可能的状态称为状态空间; 动作(A):机器所采取的动作,所有能采取的动作构成动作空间; 转移概率(P):当执行某个动作后,当前状态会以某种概率转移到另一个状态; 奖赏函数(R):在状态转移的同时,环境给反馈给机器一个奖赏。

因此,强化学习的主要任务就是通过在环境中不断地尝试,根据尝试获得的反馈信息调整策略,最终生成一个较好的策略π,机器根据这个策略便能知道在什么状态下应该执行什么动作。常见的策略表示方法有以下两种:
确定性策略:π(x)=a,即在状态x下执行a动作; 随机性策略:P=π(x,a),即在状态x下执行a动作的概率。
一个策略的优劣取决于长期执行这一策略后的累积奖赏,换句话说:可以使用累积奖赏来评估策略的好坏,最优策略则表示在初始状态下一直执行该策略后,最后的累积奖赏值最高。长期累积奖赏通常使用下述两种计算方法:

16.2 K摇摆赌博机
首先我们考虑强化学习最简单的情形:仅考虑一步操作,即在状态x下只需执行一次动作a便能观察到奖赏结果。易知:欲最大化单步奖赏,我们需要知道每个动作带来的期望奖赏值,这样便能选择奖赏值最大的动作来执行。若每个动作的奖赏值为确定值,则只需要将每个动作尝试一遍即可,但大多数情形下,一个动作的奖赏值来源于一个概率分布,因此需要进行多次的尝试。
单步强化学习实质上是K-摇臂赌博机(K-armed bandit)的原型,一般我们尝试动作的次数是有限的,那如何利用有限的次数进行有效地探索呢?这里有两种基本的想法:
仅探索法:将尝试的机会平均分给每一个动作,即轮流执行,最终将每个动作的平均奖赏作为期望奖赏的近似值。 仅利用法:将尝试的机会分给当前平均奖赏值最大的动作,隐含着让一部分人先富起来的思想。
可以看出:上述两种方法是相互矛盾的,仅探索法能较好地估算每个动作的期望奖赏,但是没能根据当前的反馈结果调整尝试策略;仅利用法在每次尝试之后都更新尝试策略,符合强化学习的思(tao)维(lu),但容易找不到最优动作。因此需要在这两者之间进行折中。
16.2.1 ε-贪心
ε-贪心法基于一个概率来对探索和利用进行折中,具体而言:在每次尝试时,以ε的概率进行探索,即以均匀概率随机选择一个动作;以1-ε的概率进行利用,即选择当前最优的动作。ε-贪心法只需记录每个动作的当前平均奖赏值与被选中的次数,便可以增量式更新。

16.2.2 Softmax
Softmax算法则基于当前每个动作的平均奖赏值来对探索和利用进行折中,Softmax函数将一组值转化为一组概率,值越大对应的概率也越高,因此当前平均奖赏值越高的动作被选中的几率也越大。Softmax函数如下所示:

16.3 有模型学习
若学习任务中的四个要素都已知,即状态空间、动作空间、转移概率以及奖赏函数都已经给出,这样的情形称为“有模型学习”。假设状态空间和动作空间均为有限,即均为离散值,这样我们不用通过尝试便可以对某个策略进行评估。
16.3.1 策略评估
前面提到:在模型已知的前提下,我们可以对任意策略的进行评估(后续会给出演算过程)。一般常使用以下两种值函数来评估某个策略的优劣:
状态值函数(V):V(x),即从状态x出发,使用π策略所带来的累积奖赏; 状态-动作值函数(Q):Q(x,a),即从状态x出发,执行动作a后再使用π策略所带来的累积奖赏。
根据累积奖赏的定义,我们可以引入T步累积奖赏与r折扣累积奖赏:

由于MDP具有马尔可夫性,即现在决定未来,将来和过去无关,我们很容易找到值函数的递归关系:

类似地,对于r折扣累积奖赏可以得到:

易知:当模型已知时,策略的评估问题转化为一种动态规划问题,即以填表格的形式自底向上,先求解每个状态的单步累积奖赏,再求解每个状态的两步累积奖赏,一直迭代逐步求解出每个状态的T步累积奖赏。算法流程如下所示:

对于状态-动作值函数,只需通过简单的转化便可得到:

16.3.2 策略改进
理想的策略应能使得每个状态的累积奖赏之和最大,简单来理解就是:不管处于什么状态,只要通过该策略执行动作,总能得到较好的结果。因此对于给定的某个策略,我们需要对其进行改进,从而得到最优的值函数。

最优Bellman等式改进策略的方式为:将策略选择的动作改为当前最优的动作,而不是像之前那样对每种可能的动作进行求和。易知:选择当前最优动作相当于将所有的概率都赋给累积奖赏值最大的动作,因此每次改进都会使得值函数单调递增。

将策略评估与策略改进结合起来,我们便得到了生成最优策略的方法:先给定一个随机策略,现对该策略进行评估,然后再改进,接着再评估/改进一直到策略收敛、不再发生改变。这便是策略迭代算法,算法流程如下所示:

可以看出:策略迭代法在每次改进策略后都要对策略进行重新评估,因此比较耗时。若从最优化值函数的角度出发,即先迭代得到最优的值函数,再来计算如何改变策略,这便是值迭代算法,算法流程如下所示:

16.4 蒙特卡罗强化学习
在现实的强化学习任务中,环境的转移函数与奖赏函数往往很难得知,因此我们需要考虑在不依赖于环境参数的条件下建立强化学习模型,这便是免模型学习。蒙特卡罗强化学习便是其中的一种经典方法。
由于模型参数未知,状态值函数不能像之前那样进行全概率展开,从而运用动态规划法求解。一种直接的方法便是通过采样来对策略进行评估/估算其值函数,蒙特卡罗强化学习正是基于采样来估计状态-动作值函数:对采样轨迹中的每一对状态-动作,记录其后的奖赏值之和,作为该状态-动作的一次累积奖赏,通过多次采样后,使用累积奖赏的平均作为状态-动作值的估计,并引入ε-贪心策略保证采样的多样性。

在上面的算法流程中,被评估和被改进的都是同一个策略,因此称为同策略蒙特卡罗强化学习算法。引入ε-贪心仅是为了便于采样评估,而在使用策略时并不需要ε-贪心,那能否仅在评估时使用ε-贪心策略,而在改进时使用原始策略呢?这便是异策略蒙特卡罗强化学习算法。

16.5 AlphaGo原理浅析
本篇一开始便提到强化学习是AlphaGo的核心技术之一,刚好借着这个东风将AlphaGo的工作原理了解一番。正如人类下棋那般“手下一步棋,心想三步棋”,Alphago也正是这个思想,当处于一个状态时,机器会暗地里进行多次的尝试/采样,并基于反馈回来的结果信息改进估值函数,从而最终通过增强版的估值函数来选择最优的落子动作。
其中便涉及到了三个主要的问题:(1)如何确定估值函数(2)如何进行采样(3)如何基于反馈信息改进估值函数,这正对应着AlphaGo的三大核心模块:深度学习、蒙特卡罗搜索树、强化学习。
1.深度学习(拟合估值函数)
由于围棋的状态空间巨大,像蒙特卡罗强化学习那样通过采样来确定值函数就行不通了。在围棋中,状态值函数可以看作为一种局面函数,状态-动作值函数可以看作一种策略函数,若我们能获得这两个估值函数,便可以根据这两个函数来完成:(1)衡量当前局面的价值;(2)选择当前最优的动作。那如何精确地估计这两个估值函数呢?这就用到了深度学习,通过大量的对弈数据自动学习出特征,从而拟合出估值函数。
2.蒙特卡罗搜索树(采样)
蒙特卡罗树是一种经典的搜索框架,它通过反复地采样模拟对局来探索状态空间。具体表现在:从当前状态开始,利用策略函数尽可能选择当前最优的动作,同时也引入随机性来减小估值错误带来的负面影响,从而模拟棋局运行,使得棋盘达到终局或一定步数后停止。

3.强化学习(调整估值函数)
在使用蒙特卡罗搜索树进行多次采样后,每次采样都会反馈后续的局面信息(利用局面函数进行评价),根据反馈回来的结果信息自动调整两个估值函数的参数,这便是强化学习的核心思想,最后基于改进后的策略函数选择出当前最优的落子动作。

熵、交叉熵和KL散度
熵(Entropy)的介绍 我们以天气预报为例子,进行熵的介绍.
-
假如只有 2 种天气,sunny 和 rainy ,那么明天对于每一种天气来说,各有 50% 的可能性.
-
此时气象部门告诉你明天是rainy,他其实减少了你的不确定信息.
-
所以,天气部门给了你 1 bit的有效信息(因为此时只有两种可能性).
-
假如只有8种天气,每一种天气出现是等可能的.
-
此时气象部门告诉你明天是 Rainy ,他其实减少了你的不确定信息,也就是告诉了你有效信息.
-
所以,天气部门给了你 3 bit的有效信息(因为8种状态需要 2^3=8 ,需要 3 bit来表示.
-
所以,有效信息的计算可以使用 log 来进行计算,计算过程如下
-
上面所有的情况都是等概率出现的,假设各种情况出现的概率不是相等的.
-
例如有75%的可能性是Sunny,25%的可能性是Rainy.
-
如果气象部门告诉你明天是Rainy
- 我们会使用概率的倒数, (概率越小,有效信息越多)
- 接着计算有效信息, ,log的等价计算)
- 因为和本来的概率相差比较大,所以获得的有效信息比较多(本来是 Rainy 的可能性小)
-
如果气象部门告诉你明天是 Sunny
- 同样计算此时的有效信息,
- 因为和本来的概率相差比较小,所以获得的信息比较少(本来是 Sunny 的可能性大)
-
从气象部门获得的信息的平均值(这个就是熵
- 简单解释: 有 75% 的可能性是Sunny ,得到晴天的有效信息是 0.41 ,所以是
-
于是我们得到了熵的计算公式.
- 熵是用来衡量获取的信息的平均值,没有获取什么信息量,则 Entropy 接近 0 .
- 下面是熵的计算公式
交叉熵(Cross-Entropy)的介绍 对于交叉熵的介绍,我们还是以天气预报作为例子来进行讲解.
- 交叉熵(Cross-Entropy)可以理解为平均的 message length ,即平均信息长度.
- 现在有 8 种天气,于是每一种天气我们可以使用 3 bit 来进行表示(000,001,010,011...)
- 此时 average message length = 3 ,所以此时Cross-Entropy
现在假设你住在一个 sunny region ,出现晴天的可能性比较大(即每一种天气不是等可能出现的),下图是每一种天气的概率.
我们来计算一下此时的熵(ntropy) ,计算的式子如下所示:
- 此时有效的信息是 2.23 bit.
- 所以再使用上面的编码方式(都使用 3 会有咒余.
- 也就是说我们每次发出 3 bit ,接收者有效信息为2.23 bit.
- 这时我们可以修改天气的 encode 的方式,可以给经常出现的天气比较小的 code 来进行表示,于是我们可以按照下图对每一种天气进行 encode .
此时的平均长度的计算如下所示(每一种天气的概率该天气code的长度): 此时的平均长度为 2.42 bit,可以看到比一开始的 3 bit有所减少.
如果我们使用相同的 code ,但是此时对应的天气的概率是不相同的,此时计算出的平均长度就会发生改变.此时每一种天气的概率如下图所示:
于是此时的信息的平均长度就是 4.58 bit ,比 Entropy 大很多 (如上图所示,我们给了概率很小的天 气的 code 也很小,概率很大的天气的 code 也很大,此时就会导致计算出的平均长度很大),下面是平均长度的计算的式子. 我们如何来理解我们给每一种天气的 code 呢,其实我们可以理解为这就是我们对每一种天气发生的可能性的预测,我们会给出现概率比较大的天气比较短的 code ,这里的概率是我们假设的,即我们有一个估计的概率,我们估计这个天气的概率比较大,所以给这个天气比较短的 code.
下图中可以表示出我们预测的q(predicted distribution)和真实分布p(true distribution).可以看到此时我们的预测概率 与真实分布 之间相差很大(此时计算出的交叉熵就会比较大)
关于上面 code 长度与概率的转换,我们可以这么来进行理解,对于概率为 的信息,他的有效信息为 .若此时 code 长度为 ,我们问概率为多少的信息的有效信息为 n ,即求解 ,则 ,所以我们就可以求出 长度与概率的转换.此时,我们就可以定义交叉熵(Cross-Entropy),这里会有两个变量,分别是p(真实的分布)和 q(预测的概率): 这个交叉熵公式的意思就是在计算消息的平均长度,我们可以这样来进行理解.
- 是将预测概率转换为 code 的长度(这里看上面 code 长度与概率的转换)
- 接着我们再将 code 的长度 乘上出现的概率 真实的概率
我们简单说明一下熵(Entropy),和交叉熵(Cross-Entropy)的性质:
- 如果预测结果是好的,,那么 p 和 q 的分布是相似的,此时 Cross-Entropy 与 Entropy 是相似 的.
- 如果 p 和 q 有很大的不同,那么 Cross-Entropy 会比 Entropy 大.
- 其中 Cross-Entropy 比 Entropy 大的部分,我们称为 relative entropy ,或是Kullback- Leibler Divergence(KL Divergence),这个就是KL-散度,我们会在后面进行详细的介 绍.
- 也就是说,三者的关系为:Cross-Entropy=Entropy+ KL Divergence
在进行分类问题的时候,我们通常会将 loss 函数设置为交叉熵(Cross-Entropy),其实现在来看这个也是很好理解,我们会有我们预测的概率 q 和实际的概率 p ,若 p 和 q 相似,则交叉熵小,若 p 和 q 不相似,则交叉熵大.
有一个要注意的是,我们通常在使用的时候会使用 10 为底的 log ,但是这个不影响 ,因为 ,我们可以通过公式进行转换.
在 PyTorch 中,CrossEntropyLoss 不是直接按照上面进行计算的,他是包含了 Softmax 的步骤的.关于在 PyTorch 中 CrossEntropyLoss 的实际计算: 详细介绍:PyTorch中交叉熵的计算CrossEntropyLoss
交叉熵损失函数
交叉熵损失函数(Cross-Entropy Loss Function)一般用于分类问题.假设样本的标签 𝑦 ∈ {1,⋯,𝐶} 为离散的类别,模型 的输出为类别标签的条件概率分布,即 并满足 我们可以用一个 维的one-hot向量 来表示样本标签.假设样本的标签为 ,那么标签向量只有第 维的值为 1 ,其余元素的值都为 0 .标签向量 可以看作样本标签的真实条件概率分布 ,即第 维(记为 ) 是类别为 的真实条件概率.假设样本的类别为 ,那么它属于第 类的概率为 1 ,属于其他类的概率为 0 . 对于两个概率分布,一般可以用交叉熵来衡量它们的差异.标签的真实分布 和模型预测分布 之间的交叉熵为 比如对于三分类问题,一个样本的标签向量为,模型预测的标签分布为,则它们的交叉熵为. 因为 为one-hot向量,上式也可以写为 其中 可以看作真实类别 的似然函数.因此,交叉熵损失函数也就是负对数似然函数(Negative Log-Likelihood).
交叉熵最小值证明
为什么当 与 的分布一致时, 有最小值:
证明如下: 因为 为定义域上的凸函数 根据琴生不等式有 其中
将 代入则有: 所以 当 时, 等号成立.
神经网络
书可参考邱锡鹏:神经网络与深度学习
一. 神经元
时刻思考各个激活函数的特点,如合理性?梯度爆炸?梯度消失?
1. Sigmoid型函数
Sigmoid型函数是指一类S型曲线函数,为两端饱和函数.常用的Sigmoid型函数有Logistic函数和Tanh函数.
>对于函数 ,若 时,其导数 ,则称其为左饱和.若 时,其导数 ,则称其为右饱和.当同时满足左、右饱和时,就称为两端饱和.
1.1 Logistic型函数
1.2 Tanh函数
Tanh函数可以看作放大并平移的Logistic函数,其值域是(−1,1)
下图给出了Logistic函数和Tanh函数的形状.Tanh函数的输出是零中心化的(Zero-Centered),而Logistic函数的输出恒大于0.非零中心化的输出会使得其后一层的神经元的输入发生偏置偏移(Bias Shift),并进一步使得梯度下降的收敛速度变慢.
2. ReLU函数
ReLU(Rectified Linear Unit,修正线性单元),也叫Rectifier函数,是目前深度神经网络中经常使用的激活函数.ReLU实际上是一个斜坡(ramp)函数,定义为 ReLU神经元训练时比较容易“死亡”,在训练时,如果参数在一次不恰当的更新后,第一个隐藏层中的某个ReLU神经元在所有的训练数据上都不能被激活,那么这个神经元自身参数的梯度永远都会是0,在以后的训练过程中永远不能被激活.这种现象称为死亡ReLU问题(DyingReLU Problem).故有以下变种
2.1 带泄露的ReLU
带泄露的ReLU(Leaky ReLU)在输入 时,保持一个很小的梯度 .这样当神经元非激活时也能有一个非零的梯度可以更新参数,避免永远不能被激活.
2.2 带参数的ReLU
带参数的ReLU(Parametric ReLU,PReLU)引入一个可学习的参数,不同神经元可以有不同的参数[He et al.,2015].对于第𝑖个神经元,其PReLU的定义为
2.3 ELU函数
ELU(Exponential Linear Unit,指数线性单元)是一个近似的零中心化的非线性函数,其定义为 其中𝛾 ≥ 0是一个超参数,决定𝑥 ≤ 0时的饱和曲线,并调整输出均值在0附近.
2.4 Softplus函数
Softplus函数可以看作Rectifier函数的平滑版本,其定义为
Softplus函数其导数刚好是Logistic函数.Softplus函数虽然也具有单侧抑制、宽兴奋边界的特性,却没有稀疏激活性.
3. Swish函数
Swish函数是一种自门控(Self-Gated)激活函数,定义为
其中为Logistic函数, 为可学习的参数或一个固定超参数. 可以看作一种软性的门控机制.当 接近于1时,门处于“开”状态,激活函数的输出近似于𝑥本身;当 接近于0时,门的状态为“关”,激活函数的输出近似于0
当 时,Swish函数变成线性函数 .当 时,Swish函数在 时近似线性,在 时近似饱和,同时具有一定的非单调性.当 时, 趋向于离散的0-1函数,Swish函数近似为ReLU函数.因此,Swish函数可以看作线性函数和ReLU函数之间的非线性插值函数,其程度由参数 控制.
4. GELU函数
TODO:GELU函数
5. Maxout单元
TODO:Maxout单元
二. 网络结构
1. 网络结构总述
目前为止,常用的神经网络有如下三种:
- 前馈网络:整个网络中的信息是朝一个方向传播,没有反向的信息传播,可以用一个有向无环路图表示.前馈网络包括全连接前馈网络和卷积神经网络等.前馈网络可以看作一个函数,通过简单非线性函数的多次复合,实现输入空间到输出空间的复杂映射.
- 记忆网络:也称为反馈网络,网络中的神经元不但可以接收其他神经元的信息,也可以接收自己的历史信息.和前馈网络相比,记忆网络中的神经元具有记忆功能,在不同的时刻具有不同的状态.记忆神经网络中的信息传播可以是单向或双向传递,因此可用一个有向循环图或无向图来表示.记忆网络包括循环神经网络、Hopfield网络、玻尔兹曼机、受限玻尔兹曼机等.记忆网络可以看作一个程序,具有更强的计算和记忆能力.为了增强记忆网络的记忆容量,可以引入外部记忆单元和读写机制,用来保存一些网络的中间状态,称为记忆增强神经网络(Memory Augmented NeuralNetwork,MANN),比如神经图灵机和记忆网络等.
- 图网络:实际应用中很多数据是图结构的数据,比如知识图谱、社交网络、分子(Molecular)网络等.图网络是定义在图结构数据上的神经网络.图中每个节点都由一个或一组神经元构成.节点之间的连接可以是有向的,也可以是无向的.每个节点可以收到来自相邻节点或自身的信息.图网络是前馈网络和记忆网络的泛化,包含很多不同的实现方式,比如图卷积网络(Graph Convolutional Network,GCN)、图注意力网络(Graph Attention Network,GAT)、消息传递神经网络(Message Passing Neural Network,MPNN)等.
三. 小批量梯度下降
小批量梯度下降法(Mini-BatchGradient Descent).
令 表示一个深度神经网络, 为网络参数,在使用小批量梯度下降进行优化时,每次选取 个训练样本 .第 次迭代时损失函数关于参数 的偏导数为 其中 为可微分的损失函数, 为批量大小(batch size).
前馈神经网络
一. 前馈神经网络
前馈神经网络(Feedforward Neural Network,FNN)是最早发明的简单人工神经网络.前馈神经网络也经常称为多层感知器(Multi-Layer Perceptron,MLP).
在前馈神经网络中,各神经元分别属于不同的层.每一层的神经元可以接收前一层神经元的信号,并产生信号输出到下一层.第 0 层称为输入层,最后一层称为输出层,其他中间层称为隐藏层.整个网络中无反馈,信号从输入层向输出层单向传播,可用一个有向无环图表示.
下面用到的记号:
- :神经网络的层数
- :第 层神经元的个数
- :第 层神经元的激活函数
- :第 层到第 层的权重矩阵
- :第 层神经元的净输入 净活性值
- :第 层神经元的输出 活性值
令 ,前馈神经网络通过不断迭代下面公式进行信息传播:
首先根据第 层神经元的活性值 ( Activation ) 计算出第 层神经元的净活性值 ( Net Activation ) ,然后经过一个激活函数得到第 层神经元的活性 值因此,我们也可以把每个神经层看作一个仿射变换 ( Affine Transformation ) 和一个非线性变换. 上述两式也可以合并写为: 或者
仿射变换:又称仿射映射,是指在几何中,一个向量空间进行一次线性变换并接上一个平移,变换为另一个向量空间.
这样, 前馈神经网络可以通过逐层的信息传递,得到网络最后的输出 .整个网络可以看作一个复合函数 ,将向量 作为第 1 层的输入 .将第 层的输出 作为整个函数的输出.
其中 表示网络中所有层的连接权重和偏置.
通用近似定理(Universal Approximation Theorem ) [Cy- benko, 1989; Hornik et al., 1989]: 令 是一个非常数、有界、单调递增的连续函数, 是一个 维的单位超立方体 是定义在 上的连续函数集合.对于任意给定的一个函数 ,存在一个整数 ,和一组实数 以及实数向量 ,以至于我 们可以定义函数 作为函数 的近似实现,即 其中 是一个很小的正数.
二. 反向传播
假设采用随机梯度下降进行神经网络参数学习,给定一个样本 ,将其输入到神经网络模型中,得到网络输出为 .假设损失函数为 ,要进行参数学习就需要计算损失函数关于每个参数的导数.
不失一般性,对第 层中的参数 和 计算偏导数.因为 的计算 涉及向量对矩阵的微分,十分繁銷,因此我们先计算 关于参数矩阵中每个元素的偏导数 .根据链式法则, 上式中的第二项都是目标函数关于第 层的神经元 的偏导数,称为误差项,可以一次计算得到,这样我们只需要计算三个偏导数, 分 别为 和 下面分别来计算这三个偏导数.
-
计算偏导数 因 ,偏导数 其中 为权重矩阵 的第 行, 表示第 个元素为 ,其余为 0 的行向量.
-
计算偏导数 因为 和 的函数关系为 ,因此偏导数 为 的单位矩阵.
-
计算偏导数 偏导数 表示第 层神经元对最终损失 的影响,也反映了最终损失对第 层神经元的敏感程度,因此一般称为第 层神经元的误差项,用 来表示.
误差项 也间接反映了不同神经元对网络能力的贡献程度,从而比较好地解决 了贡献度分配问题 ( Credit Assignment Problem, CAP ).
根据 ,有 根据 ,其中 为按位计算的函数,因此有 因此,根据链式法则,第 层的误差项为 其中 是向量的点积运算符,表示每个元素相乘.
从上式可以看出,第 层的误差项可以通过第 层的误差项计算得到,这就是误差的反向传播 ( BackPropagation, BP ).反向传播算法的含义是: 第 层的一个神经元的误差项 或敏感性 是所有与该神经元相连的第 层 的神经元的误差项的权重和.然后,再乘上该神经元激活函数的梯度.
在计算出上面三个偏导数之后,最初的式子可以写为 其中 相当于向量 和向量 的外积的第 个元素.上式可以进一步写为 因此, 关于第 层权重 的梯度为
同理, 关于第 层偏置 的梯度为 在计算出每一层的误差项之后,我们就可以得到每一层参数的梯度.因此,使用误差反向传播算法的前馈神经网络训练过程可以分为以下三步:
- 前馈计算每一层的净输入 和激活值 ,直到最后一层;
- 反向传播计算每一层的误差项 ;
- 计算每一层参数的偏导数,并更新参数.
三. 自动微分
自动微分(Automatic Differentiation,AD).
为简单起见,这里以一个神经网络中常见的复合函数的例子来说明自动微分的过程.令复合函数 为 其中 为输入标量, 和 分别为权重和偏置参数.
首先,我们将复合函数 分解为一系列的基本操作,并构成一个计算图 ( Computational Graph ).计算图是数学运算的图形化表示.计算图中的每个非叶子节点表示一个基本操作,每个叶子节点为一个输入变量或常量.下图给 出了当 时复合函数 的计算图,其中连边上的红色数字表示前向计算时复合函数中每个变量的实际取值.
从计算图上可以看出,复合函数 由 6 个基本函数 组成.如下图所示,每个基本函数的导数都十分简单,可以通过规则来实现.
整个复合函数 关于参数 和 的导数可以通过计算图上的节点 与参数 和 之间路径上所有的导数连乘来得到,即 以 为例,当 时,可以得到 如果函数和参数之间有多条路径,可以将这多条路径上的导数再进行相加,得到最终的梯度.
按照计算导数的顺序,自动微分可以分为两种模式:前向模式和反向模式.
前向模式 前向模式是按计算图中计算方向的相同方向来递归地计算梯度.以 为例,当 时,前向模式的累积计算顺序如下: 反向模式 反向模式是按计算图中计算方向的相反方向来递归地计算梯度.以 为例,当 时,反向模式的累积计算顺序如下: 前向模式和反向模式可以看作应用链式法则的两种梯度累积方式.从反向模式的计算顺序可以看出,反向模式和反向传播的计算梯度的方式相同.对于一般的函数形式 ,前向模式需要对每一个输入变量都进行一遍遍历,共需要 遍.而反向模式需要对每一个输出都进行一个遍历,共需要 遍.当 时,反向模式更高效.在前馈神经网络的参数学习中,风险函数为 ,输出为标量,因此采用反向模式为最有效的计算方式,只需要一遍计算.
静态计算图和动态计算图计算图按构建方式可以分为静态计算图(StaticCom-putational Graph)和动态计算图(Dynamic Computational Graph).在目前深度学习框架里,Theano和Ten-sorflow采用的是静态计算图,而DyNet、Chainer和PyTorch采用的是动态计算图.Tensorflow 2.0也支持了动态计算图.静态计算图是在编译时构建计算图,计算图构建好之后在程序运行时不能改变,而动态计算图是在程序运行时动态构建.两种构建方式各有优缺点.静态计算图在构建时可以进行优化,并行能力强,但灵活性比较差.动态计算图则不容易优化,当不同输入的网络结构不一致时,难以并行计算,但是灵活性比较高.
四. 卷积神经网络
相关介绍还可看知乎:深度学习中不同类型卷积的综合介绍和原文
卷积神经网络(Convolutional Neural Network,CNN或ConvNet)是一种具有局部连接、权重共享等特性的深层前馈神经网络.卷积神经网络最早主要是用来处理图像信息.在用全连接前馈网络来处理图像时,会存在以下两个问题:
- 参数太多:如果输入图像大小为 100×100×3(即图像高度为 100 ,宽度为 100 以及RGB 3 个颜色通道),在全连接前馈网络中,第一个隐藏层的每个神经元到输入层都有 100 × 100 × 3 = 30000 个互相独立的连接,每个连接都对应一个权重参数.随着隐藏层神经元数量的增多,参数的规模也会急剧增加.这会导致整个神经网络的训练效率非常低,也很容易出现过拟合.
- 局部不变性特征:自然图像中的物体都具有局部不变性特征,比如尺度缩放、平移、旋转等操作不影响其语义信息.而全连接前馈网络很难提取这些局部不变性特征,一般需要进行数据增强来提高性能.
卷积神经网络是受生物学上感受野机制的启发而提出的.感受野(Recep-tive Field)机制主要是指听觉、视觉等神经系统中一些神经元的特性,即神经元只接受其所支配的刺激区域内的信号.在视觉神经系统中,视觉皮层中的神经细胞的输出依赖于视网膜上的光感受器.视网膜上的光感受器受刺激兴奋时,将神经冲动信号传到视觉皮层,但不是所有视觉皮层中的神经元都会接受这些信号.一个神经元的感受野是指视网膜上的特定区域,只有这个区域内的刺激才能够激活该神经元.
目前的卷积神经网络一般是由卷积层、汇聚层和全连接层交叉堆叠而成的前馈神经网络.卷积神经网络有三个结构上的特性:局部连接、权重共享以及汇聚.这些特性使得卷积神经网络具有一定程度上的平移、缩放和旋转不变性.和前馈神经网络相比,卷积神经网络的参数更少.卷积神经网络主要使用在图像和视频分析的各种任务(比如图像分类、人脸识别、物体识别、图像分割等)上,其准确率一般也远远超出了其他的神经网络模型.近年来卷积神经网络也广泛地应用到自然语言处理、推荐系统等领域.
4.1 卷积
4.1.1 卷积的定义
卷积(Convolution),也叫褶积,是分析数学中一种重要的运算.在信号处理或图像处理中,经常使用一维或二维卷积.
4.1.1.1 一维卷积
一维卷积经常用在信号处理中,用于计算信号的延迟累积.假设一个信号发生器每个时刻 产生一个信号 ,其信息的衰减率为 ,即在 个时间步长后,信息为原来的 倍.假设 ,那么在时刻 收到的信号 为当前时刻产生的信息和以前时刻延迟信息的叠加. 我们把 称为滤波器 ( Filter ) 或卷积核 ( Convolution Kernel ).假设滤波器长度为 ,它和一个信号序列 的卷积为 为了简单起见,这里假设卷积的输出 的下标 从 开始.
信号序列 和滤波器 的卷积定义为 其中 表示卷积运算.一般情况下滤波器的长度 远小于信号序列 的长度.
我们可以设计不同的滤波器来提取信号序列的不同特征.比如,当令滤波器 时,卷积相当于信号序列的简单移动平均 窗口大小为 ;当令滤波器 时,可以近似实现对信号序列的二阶微分,即 下图给出了两个滤波器的一维卷积示例.可以看出,两个滤波器分别提取了输入序列的不同特征.滤波器 可以检测信号序列中的低频信息,而滤波器 可以检测信号序列中的高频信息.(高低频指信号变化的强烈程度)
4.1.1.2 二维卷积
卷积也经常用在图像处理中.因为图像为一个二维结构,所以需要将一维卷积进行扩展.给定一个图像 和一个滤波器 ,一般 ,其卷积为 为了简单起见,这里假设卷积的输出 的下标 从 开始.
输入信息 和滤波器 的二维卷积定义为 其中*表示二维卷积运算. 下图给出了二维卷积示例.
在图像处理中常用的均值滤波 ( Mean Filter ) 就是一种二维卷积,将当前位置的像素值设为滤波器窗口中所有像素的平均值,即 .
在图像处理中,卷积经常作为特征提取的有效方法.一幅图像在经过卷积操作后得到结果称为特征映射(Feature Map).下图给出在图像处理中几种常用的滤波器,以及其对应的特征映射.图中最上面的滤波器是常用的高斯滤波器,可以用来对图像进行平滑去噪;中间和最下面的滤波器可以用来提取边缘特征.
4.1.2 互相关
在机器学习和图像处理领域,卷积的主要功能是在一个图像 ( 或某种特征 ) 居滑动一个卷积核 ( 即滤波器 通过卷积操作得到一组新的特征.在计算卷积的过程中,需要进行卷积核翻转.在具体实现上,一般会以互相关操作来代替卷积,从而会减少一些不必要的操作或开销.互相关 ( Cross-Correlation ) 是一个 衡量两个序列相关性的函数,通常是用滑动窗口的点积计算来实现.给定一个图像 和卷积核 ,它们的互相关为 和公式 (4.1) 对比可知,互相关和卷积的区别仅仅在于卷积核是否进行翻转.因此互相关也可以称为不翻转卷积.
公式 (4.2) 可以表述为 其中 表示互相关运算, 表示旋转 180 度, 为输出 矩阵.
在神经网络中使用卷积是为了进行特征抽取,卷积核是否进行翻转和其特征抽取的能力无关.特别是当卷积核是可学习的参数时,卷积和互相关在能力上是等价的.因此,为了实现上 (或描述上 ) 的方便起见,我们用互相关来代替卷积.事实上,很多深度学习工具中卷积操作其实都是互相关操作.
4.1.3 卷积的变种
在卷积的标准定义基础上,还可以引入卷积核的滑动步长和零填充来增加卷积的多样性,可以更灵活地进行特征抽取.
- 步长(Stride)是指卷积核在滑动时的时间间隔.下图左给出了步长为2的卷积示例.(步长也可以小于1,即微步卷积)
- 零填充(Zero Padding)是在输入向量两端进行补零.下图右给出了输入的两端各补一个零后的卷积示例.
假设卷积层的输入神经元个数为 ,卷积大小为 ,步长为 ,在输入两端各填补 个 0 ( zero padding ) ,那么该卷积层的神经元数量为
一般常用的卷积有以下三类:
- 窄卷积 ( Narrow Convolution ) : 步长 ,两端不补零 ,卷积后输出长度为
- 宽卷积 ( Wide Convolution ) : 步长 ,两端补零 ,卷积后输出长度 . (3) 等宽卷积 ( Equal-Width Convolution ) 步长 ,两端补零 ,卷积后输出长度 .上图右就是一个等宽卷积示例.
4.1.4 卷积的数学性质
4.1.4.1 交换性
如果不限制两个卷积信号的长度,真正的翻转卷积是具有交换性的,即 .对于互相关的“卷积”,也同样具有一定的“交换性”.
我们先介绍宽卷积 ( Wide Convolution ) 的定义.给定一个二维图像 和一个二维卷积核 , 对图像 进行零填充,两端各补 和 个零,得到全填充 ( Full Padding 的图像 .图像 和卷积核 的宽卷积定义为 其中 表示宽卷积运算.当输入信息和卷积核有固定长度时,它们的宽卷积依然具有交换性,即 其中 表示旋转 180 度.
4.1.4.2 导数
假设 ,其中 ,函数 为一个标量函数,则 从上式可以看出, 关于 的偏导数为 和 的卷积 同理得到, 其中当 ,或 ,或 ,或 时, .即相当于对 进行了 的零填充.
从上式可以看出, 关于 的偏导数为 和 的宽卷积.上式中的卷积是真正的卷积而不是互相关,为了一致性,我们用互相关的“卷积”,即 其中 表示旋转 180 度.
4.2 卷积神经网络
卷积神经网络一般由卷积层、汇聚层和全连接层构成.
4.2.1 用卷积代替全连接
在全连接前馈神经网络中,如果第 层有 个神经元,第 层有 个神经元,连接边有 个,也就是权重矩阵有 个参数.当 和 都很大时,权重矩阵的参数非常多,训练的效率会非常低.
如果采用卷积来代替全连接,第 层的净输入 为第 层活性值 和卷积核 的卷积,即 其中卷积核 为可学习的权重向量, 为可学习的偏置.
根据卷积的定义,卷积层有两个很重要的性质 :
局部连接 在卷积层 假设是第 层 中的每一个神经元都只和下一层 第 层 ) 中某个局部窗口内的神经元相连,构成一个局部连接网络.如下图所示,卷积层和下一层之间的连接数大大减少,由原来的 个连接变为 个连接, 为卷积核大小.
权重共享 从上式可以看出,作为参数的卷积核 对于第 层的所有的神经元都是相同的.如下图中,所有的同颜色连接上的权重是相同的.权重共享可以理解为一个卷积核只捕捉输入数据中的一种特定的局部特征.因此,如果要提取多种特征就需要使用多个不同的卷积核.
由于局部连接和权重共享,卷积层的参数只有一个 维的权重 和 1 维的偏置 ,共 个参数.参数个数和神经元的数量无关.此外,第 层的神经 元个数不是任意选择的,而是满足 .
4.2.2 卷积层
卷积层的作用是提取一个局部区域的特征,不同的卷积核相当于不同的特征提取器.上一节中描述的卷积层的神经元和全连接网络一样都是一维结构.由于卷积网络主要应用在图像处理上,而图像为二维结构,因此为了更充分地利用图像的局部信息,通常将神经元组织为三维结构的神经层,其大小为高度 宽 度 深度 ,由 个 大小的特征映射构成.
特征映射 ( Feature Map ) 为一幅图像 ( 或其他特征映射 ) 在经过卷积提取到的特征,每个特征映射可以作为一类抽取的图像特征.为了提高卷积网络的表示能力,可以在每一层使用多个不同的特征映射,以更好地表示图像的特征.
在输入层,特征映射就是图像本身.如果是灰度图像,就是有一个特征映射,输入层的深度 ;如果是彩色图像,分别有 三个颜色通道的特征映射,输入层的深度 .
不失一般性,假设一个卷积层的结构如下:
- 输入特征映射组: 为三维张量 ( Tensor ), 其中每个切 片 ( Slice ) 矩阵 为一个输入特征映射, ;
- 输出特征映射组: 为三维张量,其中每个切片矩阵 为一个输出特征映射, ;
- 卷积核: 为四维张量,其中每个切片矩阵 为一个二维卷积核, .
下图给出卷积层的三维结构表示.
为了计算输出特征映射 ,用卷积核 分别对输入特征映射 进行卷积,然后将卷积结果相加,并加上一个标量偏置 得到卷积层的净输入 ,再经过非线性激活函数后得到输出特征映射 . 其中 为三维卷积核, 为非线性激活函数,一般用 函数.
整个计算过程如下图所示.如果希望卷积层输出 个特征映射,可以将上述计算过程重复 次,得到 个输出特征映射 .
在输入为 ,输出为 的卷积层中,每一个输出特征映射都需要 个卷积核以及一个偏置.假设每个卷积核的大小为 ,那么 共需要 个参数.
4.2.3 汇聚层
汇聚层 ( Pooling Layer ) 也叫子采样层 ( Subsampling Layer ) ,其作用是进行特征选择,降低特征数量,从而减少参数数量.
卷积层虽然可以显著减少网络中连接的数量,但特征映射组中的神经元个数并没有显著减少.如果后面接一个分类,分类器的输入维数依然很高,很容易出现过拟合.为了解决这个问题,可以在卷积层之后加上一个汇聚层,从而降低特征维数,避免过拟合.
假设汇聚层的输入特征映射组为 ,对于其中每一个特征映射 ,将其划分为很多区域 ,这些区域可以重叠,也可以不重叠,汇聚 ( Pooling ) 是指对每个区域进行下采样 ( Down Sampling ) 得到一个值,作为这个区域的概括.
常用的汇聚函数有两种:
- 最大汇聚 ( Maximum Pooling 或 Max Pooling ) :对于一个区域 ,选择这个区域内所有神经元的最大活性值作为这个区域的表示,即 其中 为区域 内每个神经元的活性值.
- 平均汇聚 ( Mean Pooling ) :一般是取区域内所有神经元活性值的平均值,即 对每一个输入特征映射 的 个区域进行子采样,得到汇聚层的输出特征映射 .
下图给出了采样最大汇聚进行子采样操作的示例.可以看出,汇聚层不但可以有效地减少神经元的数量,还可以使得网络对一些小的局部形态改变保持不变性,并拥有更大的感受野.
目前主流的卷积网络中,汇聚层仅包含下采样操作.但在早期的一些卷积网络 比如 LeNet-5 ) 中,有时也会在汇聚层使用非线性激活函数,比如 其中 为汇聚层的输出, 为非线性激活函数, 和 为可学习的标量权重和偏置.
典型的汇聚层是将每个特征映射划分为 大小的不重叠区域,然后使用最大汇聚的方式进行下采样.汇聚层也可以看作一个特殊的卷积层,卷积核大小 为 ,步长为 ,卷积核为 函数或 mean 函数.过大的采样区域会急剧减少神经元的数量,也会造成过多的信息损失.
4.2.4 卷积网络的整体结构
一个典型的卷积网络是由卷积层、汇聚层、全连接层交叉堆叠而成.目前常用的卷积网络整体结构如下图所示.一个卷积块为连续 个卷积层和 个汇聚层 通常设置为 为 0 或 1 ).一个卷积网络中可以堆叠 个连续的卷积块,然后在后面接着 个全连接层 的取值区间比较大,比如 或者更大; 一般为 ).
目前,卷积网络的整体结构趋向于使用更小的卷积核 比如 和 以及更深的结构 比如层数大于 50 ).此外,由于卷积的操作性越来越灵活 ( 比如不同的步长 ),汇聚层的作用也变得越来越小,因此目前比较流行的卷积网络中,汇聚层的比例正在逐渐降低,趋向于全卷积网络.
4.3 参数学习(卷积网络的反向传播)
在卷积网络中,参数为卷积核中权重以及偏置.和全连接前软网络类似,卷积网络也可以通过误差反向传播算法来进行参数学习.
在全连接前馈神经网络中,梯度主要通过每一层的误差项 进行反向传播,并进一步计算每层参数的梯度.
在卷积神经网络中,主要有两种不同功能的神经层:卷积层和汇聚层.而参数为卷积核以及偏置,因此只需要计算卷积层中参数的梯度.
不失一般性,对第 层为卷积层,第 层的输入特征映射为 ,通过卷积计算得到第 层的特征映射净输入 .第 层的第 个特征映射净输入 其中 和 为卷积核以及偏置.第 层中共有 个卷积核和 个偏 置,可以分别使用链式法则来计算其梯度.
根据上式,损失函数 关于第 层的卷积核 的偏 导数为 其中 为损失函数关于第 层的第 个特征映射净输入 的偏导数.
同理可得,损失函数关于第 层的第 个偏置 的偏导数为 在卷积网络中,每层参数的梯度依赖其所在层的误差项 .
4.3.1 卷积神经网络的反向传播算法
卷积层和汇聚层中误差项的计算有所不同,因此我们分别计算其误差项.
汇聚层 当第 层为汇聚层时,因为汇聚层是下采样操作, 层的每个神经元的误差项 对应于第 层的相应特征映射的一个区域. 层的第 个特征映射中的每个神经元都有一条边和 层的第 个特征映射中的一个神经元相连.根据链式法则,第 层的一个特征映射的误差项 ,只需要将 层对应特征映射的误差项 进行上采样操作 和第 层的大小一样 ,再和 层特征映射的激活值偏导数逐元素相乘,就得到了 .
第 层的第 个特征映射的误差项 的具体推导过程如下: 其中 为第 层使用的激活函数导数,up 为上采样函数 ( up sampling ),与汇聚层中使用的下采样操作刚好相反.如果下采样是最大汇聚,误差项 中每个值会直接传递到上一层对应区域中的最大值所对应的神经元,该区域中其他神经元的误差项都设为 .如果下采样是平均汇聚,误差项 中每个值会被平均分配到上一层对应区域中的所有神经元上.
卷积层 当 层为卷积层时,假设特征映射净输入 ,其中 第 个特征映射净输入 其中 和 为第 层的卷积核以及偏置.第 层中共有 个卷积核和 个偏置.
第 层的第 个特征映射的误差项 的具体推导过程如下: 其中 为宽卷积.
4.4 几种典型的卷积神经网络
4.4.1 LeNet-5
LeNet-5[LeCun et al., 1998 ] 虽然提出的时间比较早,但它是一个非常成功的神经网络模型.基于 LeNet-5 的手写数字识别系统在 20 世纪 90 年代被美国很多银行使用,用来识别支票上面的手写数字. LeNet-5的网络结构如下图所示.
LeNet-5共有 7 层,接受输入图像大小为 ,输出对应 10 个类别的得分. LeNet-5中的每一层结构如下:
- C1 层是卷积层,使用 6 个 的卷积核,得到 6 组大小为 的特征映射.因此, 层的神经元数量为 , 可训练参数数量为 , 连接数为 (包括偏置在内,下同 ).
- S2层为汇聚层,采样窗口为 ,使用平均汇聚,并使用一个非线性函数.神经元个数为 , 可训练参数数量为 , 连接数为 .
- C3 层为卷积层. LeNet-5 中用一个连接表来定义输入和输出特征映 射之间的依赖关系,如图5.11所示,共使用 60 个 的卷积核,得到 16 组大 小为 的特征映射. 神经元数量为 , 可训练参数数量为 , 连接数为
- S4 层是一个汇聚层, 采样窗口为 , 得到 16 个 大小的特征映射, 可训练参数数量为 , 连接数为
- C5 层是一个卷积层, 使用 个 的卷积核, 得到 120 组大小为 的特征映射. 层的神经元数量为 120, 可训练参数数量为 , 连接数为
- F6层是一个全连接层,有 84 个神经元, 可训练参数数量为 1 ) 164. 连接数和可训练参数个数相同,为
- 输出层:输出层由 10 个径向基函数 ( Radial Basis Function, ) 组成.这里不再详述.
连接表 从公式可以看出, 卷积层的每一个输出特征映射都依赖于所有输入特征映射,相当于卷积层的输入和输出特征映射之间是全连接的关系. 实际上, 这种全连接关系不是必须的. 我们可以让每一个输出特征映射都依赖于少数几个输入特征映射. 定义一个连接表 ( Link Table ) 来描述输入和输出特征映射之间的连接关系. 在 LeNet-5 中, 连接表的基本设定如下图所示. C3层的第 0 -5 个特征映射依赖于S2层的特征映射组的每 3 个连续子集,第 6-11个特征映射依赖于 层的特征映射组的每 4 个连续子集, 第 12-14 个特征映射依赖于 S2 层的特征映射的每 4 个不连续子集,第 15 个特征映射依赖于 层的所有特征映射.
如果第 个输出特征映射依赖于第 个输入特征映射, 则 , 否则为 为其中 为 大小的连接表. 假设连接表 的非零个数为 , 每个卷积核的大小为 ,那么共需要 参数.
4.4.2 AlexNet
AlexNet[Krizhevsky et al., 2012 ] 是第一个现代深度卷积网络模型, 其首次使用了很多现代深度卷积网络的技术方法, 比如使用 进行并行训练, 采有了 作为非线性激活函数, 使用 Dropout 防止过拟合, 使用数据增强来提高 模型准确率等. AlexNet 赢得了 2012 年 ImageNet 图像分类竞赛的冠军.
AlexNet 的结构如下图所示,包括 5 个卷积层、3个汇聚层和 3 个全连接层(其中最后一层是使用 Softmax 函数的输出层).因为网络规模超出了当时的单个 GPU的内存限制, AlexNet 将网络拆为两半,分别放在两个 上, GPU间只 在某些层 比如第 3 层 进行通信.
AlexNet 的输入为 的图像,输出为 1000 个类别的条件概率,具体结构如下:
- 第一个卷积层,使用两个大小为 的卷积核,步长 , 零填充 , 得到两个大小为 的特征映射组.
- 第一个汇聚层,使用大小为 的最大汇聚操作,步长 , 得到两个 的特征映射组.
- 第二个卷积层,使用两个大小为 的卷积核,步长 , 零填充 , 得到两个大小为 的特征映射组.
- 第二个汇聚层,使用大小为 的最大汇聚操作,步长 , 得到两个大小为 的特征映射组.
- 第三个卷积层为两个路径的融合,使用一个大小为 的卷积核, 步长 , 零填充 , 得到两个大小为 的特征映射组.
- 第四个卷积层, 使用两个大小为 的卷积核,步长 , 零填充 ,得到两个大小为 的特征映射组.
- 第五个卷积层, 使用两个大小为 的卷积核,步长 , 零填充 , 得到两个大小为 的特征映射组.
- 第三个汇聚层,使用大小为 的最大汇聚操作,步长 , 得到两个大小为 的特征映射组.
- 三个全连接层,神经元数量分别为 4096,4096 和 1000 .此外, AlexNet 还在前两个汇聚层之后进行了局部响应归一化 ( Local Re'nonse Normalization.LRN ) 以增强模型的泛化能力.
4.4.3 Inception网络
在卷积网络中,如何设置卷积层的卷积核大小是一个十分关键的问题. 在 Inception 网络中, 一个卷积层包含多个不同大小的卷积操作, 称为Inception 模块. Inception 网络是由有多个 Inception 模块和少量的汇聚层堆叠而成.
Inception 模块同时使用 等不同大小的卷积核, 并将得到的特征映射在深度上拼接 ( 堆叠 ) 起来作为输出特征映射.
下图给出了v1版本的 Inception 模块结构, 采用了 4 组平行的特征抽取方 式, 分别为 的卷积和 的最大汇聚. 同时, 为了提高计算效 率,减少参数数量, Inception 模块在进行 的卷积之前、 的最大汇聚之后,进行一次 的卷积来减少特征映射的深度. 如果输入特征映射之间存在冗余信息, 的卷积相当于先进行一次特征抽取.
Inception 网络有多个版本, 其中最早的 Inception v1 版本就是非常著名的 GoogLeNet [Szegedy et al., 2015]. GoogLeNet 赢得了 2014年 ImageNet 图像分 类竞赛的冠军.
GoogLeNet 由 9 个 Inception v1 模块和 5 个汇聚层以及其他一些卷积层和全连接层构成, 总共为 22 层网络,如下图所示.
为了解决梯度消失问题, GoogLeNet 在网络中间层引入两个辅助分类器来加强监督信息.
Inception 网络有多个改进版本, 其中比较有代表性的有Inception v3网络[Szegedy et al., 2016]. Inception v3 网络用多层的小卷积核来替换大的卷积核,以减少计算量和参数量,并保持感受野不变. 具体包括 ) 使用两层 的卷积来替换 中的 的卷积; 2 ) 使用连续的 和 来替换 的卷积.
此外, Inception v3 网络同时也引入了标签平滑以及批量归一化等优化方法进行训练.
4.4.4 残差网络
残差网络 ( Residual Network, ResNet ) 通过给非线性的卷积层增加直连边 ( Shortcut Connection ) ( 也称为残差连接 ( Residual Connection ) ) 的方式来提高信息的传播效率.
假设在一个深度网络中,我们期望一个非线性单元 (可以为一层或多层的卷积层 去逼近一个目标函数为 .如果将目标函数拆分成两部分 : 恒等函数 ( Identity Function ) 和残差函数 ( Residue Function) . 根据通用近似定理,一个由神经网络构成的非线性单元有足够的能力来近似逼近原始目标函数或残差函数,但实际中后者更容易学习 [He et al., 2016].因此,原来的优化问题可以转换为:让非线性单元 去近似残差函数 , 并用 去逼近 .
下图给出了一个典型的残差单元示例.残差单元由多个级联的 ( 等宽 ) 卷积层和一个跨层的直连边组成,再经过 ReLU 激活后得到输出.
残差网络就是将很多个残差单元串联起来构成的一个非常深的网络.和残差网络类似的还有 Highway Network[Srivastava et al., 2015].
4.5 其他卷积方式
在第4.1.3节中介绍了一些卷积的变种,可以通过步长和零填充来进行不同的卷积操作.本节介绍一些其他的卷积方式.
4.5.1 转置卷积
我们一般可以通过卷积操作来实现高维特征到低维特征的转换.比如在一维卷积中,一个 5 维的输入特征,经过一个大小为 3 的卷积核,其输出为 3 维特征.如果设置步长大于 1 ,可以进一步降低输出特征的维数.但在一些任务中,我们需要将低维特征映射到高维特征,并且依然希望通过卷积操作来实现.
假设有一个高维向量为 和一个低维向量为 .如果用仿射变换 ( Affine Transformation ) 来实现高维到低维的映射, 其中 为转换矩阵.我们可以很容易地通过转置 来实现低维到高维的反向映射,即 需要说明的是,上两式并不是逆运算,两个映射只是形式上 的转置关系.
在全连接网络中,忽略激活函数,前向计算和反向传播就是一种转置关系.比如前向计算时,第 层的净输入为 ,反向传播时,第 层的误差项为 .
卷积操作也可以写为仿射变换的形式.假设一个 5 维向量 ,经过大小为 3 的卷积核 进行卷积,得到 3 维向量 .卷积操作可以写为 其中 是一个稀疏矩阵,其非零元素来自于卷积核 中的元素.
如果要实现 3 维向量 到 5 维向量 的映射,可以通过仿射矩阵的转置来实现,即 其中 表示旋转 180 度.
可以看出,从仿射变换的角度来看两个卷积操作 和 也是形式上的转置关系.因此,我们将低维特征映射到高维特征的卷积操作称为转置卷积 ( Transposed Convolution ) [Dumoulin et al., 2016],也称为反卷积 ( Deconvolution ) [Zeiler et al., 2011].
在卷积网络中,卷积层的前向计算和反向传播也是一种转置关系.
对一个 维的向量 ,和大小为 的卷积核,如果希望通过卷积操作来映射到更高维的向量,只需要对向量 进行两端补零 ,然后进行卷积,可以得到 维的向量.
转置卷积同样适用于二维卷积.下图给出了一个步长 ,无零填充 的二维卷积和其对应的转置卷积.
微步卷积 我们可以通过增加卷积操作的步长 来实现对输入特征的下采样操作,大幅降低特征维数.同样,我们也可以通过减少转置卷积的步长 来实现上采样操作,大幅提高特征维数.步长 的转置卷积也称为微步卷积 ( Fractionally-Strided Convolution ) [Long et al., 2015].为了实现微步卷积,我们可以在输入特征之间插入 0 来间接地使得步长变小.
如果卷积操作的步长为 ,希望其对应的转置卷积的步长为 ,需要在输入特征之间插入 个 0 来使得其移动的速度变慢.
以一维转置卷积为例, 对一个 维的向量 ,和大小为 的卷积核,通过对向量 进行两端补零 ,并且在每两个向量元素之间插入 个 0 ,然后进行步长为 1 的卷积,可以得到 维的向量.
下图给出了一个步长 ,无零填充 的二维卷积和其对应的转置卷积.
4.5.2 空洞卷积
对于一个卷积层,如果希望增加输出单元的感受野,一般可以通过三种方式实现: 1 )增加卷积核的大小; 2 ) 增加层数,比如两层 的卷积可以近似一层 卷积的效果 ; 3 ) 在卷积之前进行汇聚操作.前两种方式会增加参数数量,而第三种方式会丢失一些信息.
空洞卷积(Atrous Convolution ) 是一种不增加参数数量,同时增加输出单元感受野的一种方法,也称为膨胀卷积 ( Dilated Convolution ) [Chen et al. 2018; Yu et al., 2015.
空洞卷积通过给卷积核插入 “空洞”来变相地增加其大小.如果在卷积核的每两个元素之间插入 个空洞,卷积核的有效大小为 其中 称为膨胀率 ( Dilation Rate ).当 时卷积核为普通的卷积核.
下图给出了空洞卷积的示例.
4.6 总结和深入阅读
卷积神经网络是受生物学上感受野机制启发而提出的.1959 年,[Hubel et al., 1959] 发现在猫的初级视觉皮层中存在两种细胞:简单细胞和复杂细胞.这两种细胞承担不同层次的视觉感知功能 [Hubel et al., 1962].简单细胞的感受野是狭长型的,每个简单细胞只对感受野中特定角度 ( orientation ) 的光带敏感,而复杂细胞对于感受野中以特定方向 ( direction ) 移动的某种角度 ( ori- entation ) 的光带敏感.受此启发,福岛邦彦 ( Kunihiko Fukushima ) 提出了一种带卷积和子采样操作的多层神经网络:新知机 ( Neocognitron ) [Fukushima, 1980].但当时还没有反向传播算法,新知机采用了无监督学习的方式来训练.[LeCun et al., 1989 ] 将反向传播算法引入了卷积神经网络,并在手写体数字识别上取得了很大的成功 [LeCun et al., 1998].
AlexNet[Krizhevsky et al., 2012 ] 是第一个现代深度卷积网络模型,可以说是深度学习技术在图像分类上真正突破的开端. AlexNet 不用预训练和逐层训练,首次使用了很多现代深度网络的技术,比如使用 GPU 进行并行训练,采用了 ReLU作为非线性激活函数,使用 Dropout 防止过拟合,使用数据增强来提高模型准确率等.这些技术极大地推动了端到端的深度学习模型的发展.
在 AlexNet 之后,出现了很多优秀的卷积网络,比如 VGG 网络 [Simonyan et al., 2014]、Inception v1,v2, v4 网络 [Szegedy et al., 2015, 2016, 2017] 、残差网 络 [He et al., 2016] 等.
目前,卷积神经网络已经成为计算机视觉领域的主流模型.通过引入跨层的直连边,可以训练上百层乃至上千层的卷积网络.随着网络层数的增加,卷积层越来越多地使用 和 大小的小卷积核,也出现了一些不规则的卷积操作,比如空洞卷积 [Chen et al., 2018; Yu et al., 2015] 可变形卷积 [Dai et al., 2017] 等.网络结构也逐渐趋向于全卷积网络 ( Fully Convolutional Network, FCN ) [Long et al., 2015],减少汇聚层和全连接层的作用.
各种卷积操作的可视化示例可以参考 [Dumoulin et al., 2016].
循环神经网络
循环神经网络(Recurrent Neural Network,RNN)是一类具有短期记忆能力的神经网络.在循环神经网络中,神经元不但可以接受其他神经元的信息,也可以接受自身的信息,形成具有环路的网络结构.和前馈神经网络相比,循环神经网络更加符合生物神经网络的结构.循环神经网络已经被广泛应用在语音识别、语言模型以及自然语言生成等任务上.循环神经网络的参数学习可以通过随时间反向传播算法来学习.随时间反向传播算法即按照时间的逆序将错误信息一步步地往前传递.当输入序列比较长时,会存在梯度爆炸和消失问题,也称为长程依赖问题.为了解决这个问题,人们对循环神经网络进行了很多的改进,其中最有效的改进方式引入门控机制(Gating Mechanism).
此外,循环神经网络可以很容易地扩展到两种更广义的记忆网络模型:递归神经网络和图网络.
一. 循环神经网络
给定一个输入序列 ,循环神经网络通过下面公式更新带反馈边的隐藏层的活性值 :
其中 , 为一个非线性函数,可以是一个前馈网络.
给出了循环神经网络的示例,其中“延时器”为一个虚拟单元,记录神经元的最近一次(或几次)活性值.
从数学上讲,上式可以看成一个动力系统.因此,隐藏层的活性值 在很多文献上也称为状态(State)或隐状态(Hidden State).由于循环神经网络具有短期记忆能力,相当于存储装置,因此其计算能力十分强大.理论上,循环神经网络可以近似任意的非线性动力系统.前馈神经网络可以模拟任何连续函数,而循环神经网络可以模拟任何程序.
二. 简单循环网络
简单循环网络(Simple Recurrent Network,SRN)只有一个隐藏层.在一个两层的前馈神经网络中,连接存在于相邻的层与层之间,隐藏层的节点之间是无连接的.而简单循环网络增加了从隐藏层到隐藏层的反馈连接.
令向量 表示在时刻 时网络的输入, 表示隐藏层状态(即隐藏层神经元活性值),则 不仅和当前时刻的输入 相关,也和上一个时刻的隐藏层状态 相关.简单循环网络在时刻 的更新公式为
其中 为隐藏层的净输入, 为状态-状态权重矩阵, 为状态-输入权重矩阵, 为偏置向量, 是非线性激活函数,通常为Logistic函数或Tanh函数.也经常直接写为
下图给出了按时间展开的循环神经网络:
循环神经网络的通用近似定理
一个完全连接的循环网络是任何非线性动力系统的近似器.
定理:循环神经网络的通用近似定理[Haykin,2009]:如果一个完全连接的循环神经网络有足够数量的sigmoid型隐藏神经元,那么它可以以任意的准确率去近似任何一个非线性动力系统
其中 为每个时刻的隐状态, 是外部输入, 是可测的状态转换函数, 是连续输出函数,并且对状态空间的紧致性没有限制.
三. 参数学习
循环神经网络的参数可以通过梯度下降方法来进行学习.
以随机梯度下降为例,给定一个训练样本 ,其中 为长度是 的输入序列, 是长度为 的标签序列.即在每个时刻 ,都有一个监督信息 ,我们定义时刻 的损失函数为 其中 为第 时刻的输出, 为可微分的损失函数,比如交叉熵.那么整个序列的损失函数为 整个序列的损失函数 关于参数 的梯度为 即每个时刻损失 对参数 的偏导数之和.
循环神经网络中存在一个递归调用的函数 , 因此其计算参数梯度的方式和前馈神经网络不太相同.在循环神经网络中主要有两种计算梯度的方式:随 时间反向传播 ( BPTT ) 算法和实时循环学习 ( RTRL ) 算法.
3.1 随时间反向传播
随时间反向传播 ( BackPropagation Through Time, BPTT ) 算法的主要思想是通过类似前馈神经网络的错误反向传播算法 [Werbos, 1990]来计算梯度.
BPTT 算法将循环神经网络看作一个展开的多层前馈网络,其中“每一层”对应循环网络中的“每个时刻”.这样,循环神经网络就可以按照前馈网络中的反向传播算法计算参数梯度.在“展开”的前馈网络中,所有层的参数是共 享的,因此参数的真实梯度是所有“展开层”的参数梯度之和.
计算偏导数 先来计算第 时刻损失对参数 的偏导数 .
因为参数 和隐藏层在每个时刻 的净输入 有关,因此第 时刻的损失函数 关于参数 的梯度为: 其中 表示**“直接”偏导数**,即公式 中保持 不变,对 进行求偏导数,得到 其中 为第 时刻隐状态的第 维; 是除了第 行值为 外,其余都为 0 的行向量.
定义误差项 为第 时刻的损失对第 时刻隐藏神经层的净输入 的导数,则当 时 由上面三式得到 将上式写成矩阵形式为 下图给出了误差项随时间进行反向传播算法的示例.
参数梯度 由几式,得到整个序列的损失函数 关于参数 的梯度 同理可得, 关于权重 和偏置 的梯度为 计算复杂度 在BPTT算法中,参数的梯度需要在一个完整的“前向”计算和“反向”计算后才能得到并进行参数更新.
3.2 实时循环学习
与反向传播的 BPTT 算法不同的是, 实时循环学习 ( Real-Time Recurrent Learning, RTRL ) 是通过前向传播的方式来计算梯度 [Williams et al., 1995].
假设循环神经网络中第 时刻的状态 为 其关于参数 的偏导数为 其中 是除了第 行值为 外,其余都为 0 的行向量.
RTRL 算法从第 1 个时刻开始,除了计算循环神经网络的隐状态之外,还利用上式依次前向计算偏导数 .
这样,假设第 个时刻存在一个监督信息,其损失函数为 ,就可以同时计 算损失函数对 的偏导数 这样在第 时刻,可以实时地计算损失 关于参数 的梯度,并更新参数.参数 和 的梯度也可以同样按上述方法实时计算.
两种算法比较 RTRL算法和 BPTT 算法都是基于梯度下降的算法,分别通过前向模式和反向模式应用链式法则来计算梯度.在循环神经网络中,一般网络输出维度远低于输入维度,因此 BPTT 算法的计算量会更小,但是 BPTT 算法需要保存所有时刻的中间梯度,空间复杂度较高.RTRL算法不需要梯度回传,因此非常 适合用于需要在线学习或无限序列的任务中.
四. 长程依赖问题
循环神经网络在学习过程中的主要问题是由于梯度消失或爆炸问题,很难建模长时间间隔 ( Long Range ) 的状态之间的依赖关系.
在 BPTT 算法中,将公式(6.36)展开得到 如果定义 ,则 若 ,当 时, .当间隔 比较大时,梯度也变得很大,会造成系统不稳定,称为梯度爆炸问题 ( Gradient Exploding Problem ).
相反,若 ,当 时, .当间隔 比较大时,梯度也变得非常小,会出现和深层前馈神经网络类似的梯度消失问题 ( Vanishing Gradient Problem ).
要注意的是,在循环神经网络中的梯度消失不是说 的梯度消失了,而是 的梯度消失了 (当间隔 比较大时 .也就是说,参数 的更新主要靠当前时刻 的几个相邻状态 来更新,长距离的状态对参数 没有影响.
由于循环神经网络经常使用非线性激活函数为 Logistic 函数或 Tanh 函数作为非线性激活函数,其导数值都小于 1 ,并且权重矩阵 也不会太大,因此如果时间间隔 过大, 会趋向于 0 ,因而经常会出现梯度消失问题.
虽然简单循环网络理论上可以建立长时间间隔的状态之间的依赖关系,但是由于梯度爆炸或消失问题,实际上只能学习到短期的依赖关系.这样,如果时刻 的输出 依赖于时刻 的输入 ,当间隔 比较大时 ,简单神经网络很难建模这种长距离的依赖关系,称为长程依赖问题 ( Long-Term Dependencies Problem ).
4.1 改进方法
为了避免梯度爆炸或消失问题,一种最直接的方式就是选取合适的参数,同时使用非饱和的激活函数,尽量使得 ,这种方式需要足够的人 工调参经验,限制了模型的广泛应用.比较有效的方式是通过改进模型或优化方法来缓解循环网络的梯度爆炸和梯度消失问题.
梯度爆炸 一般而言,循环网络的梯度爆炸问题比较容易解决,一般通过权重衰减或梯度截断来避免.
权重衰减是通过给参数增加 或 范数的正则化项来限制参数的取值范围,从而使得 .梯度截断是另一种有效的启发式方法,当梯度的模大于一定阈值时,就将它截断成为一个较小的数.
梯度消失 梯度消失是循环网络的主要问题.除了使用一些优化技巧外,更有效的方式就是改变模型,比如让 ,同时令 为单位矩阵,即 其中 是一个非线性函数, 为参数. 公式(6.49)中, 和 之间为线性依赖关系,且权重系数为 1 ,这样就不存在梯度爆炸或消失问题.但是,这种改变也丢失了神经元在反馈边上的非线性 激活的性质,因此也降低了模型的表示能力.
为了避免这个缺点, 我们可以采用一种更加有效的改进策略: 这样 和 之间为既有线性关系,也有非线性关系,并且可以缓解梯度消失问题.但这种改进依然存在两个问题:
- 梯度爆炸问题:令 为在第 时刻函数 的输入,在计算公式(6.34)中的误差项 时,梯度可能会过大,从而导致梯度爆炸问题.
- 记忆容量 ( Memory Capacity ) 问题:随着 不断累积存储新的输入信息,会发生饱和现象.假设 为 Logistic 函数,则随着时间 的增长, 会变得越来越大,从而导致 变得饱和.也就是说,隐状态 可以存储的信息是有限的,随着记忆单元存储的内容越来越多,其丢失的信息也越来越多.
为了解决这两个问题,可以通过引入门控机制来进一步改进模型.
五. 基于门控的循环神经网络
为了改善循环神经网络的长程依赖问题,一种非常好的解决方案是在公式(6.50)的基础上引入门控机制来控制信息的累积速度,包括有选择地加入新的信息,并有选择地遗忘之前累积的信息. 这一类网络可以称为基于门控的循环神经网络 ( Gated RNN ) . 本节中,主要介绍两种基于门控的循环神经网络:长短期记忆网络和门控循环单元网络.
5.1 长短期记忆网络
长短期记忆网络 ( Long Short-Term Memory Network, LSTM ) [Gers et al. 2000; Hochreiter et al., 1997] 是循环神经网络的一个变体,可以有效地解决简单 循环神经网络的梯度爆炸或消失问题.
在公式 的基础上,LSTM 网络主要改进在以下两个方面:
新的内部状态 LSTM 网络引入一个新的内部状态 ( internal state ) 专门进行线性的循环信息传递,同时 非线性地 输出信息给隐藏层的外部状态 .内部状态 通过下面公式计算; 其中 和 为三个门 gate 来控制信息传递的路径; 为向量元素乘积 为上一时刻的记忆单元 是通过非线性函数得到的候选状态: 在每个时刻 网络的内部状态 记录了到当前时刻为止的历史信息.
门控机制 在数字电路中,门 gate 为一个二值变量 代表关闭状态,不许任何信息通过 代表开放状态,允许所有信息通过
LSTM 网络引入门控机制 ( Gating Mechanism ) 来控制信息传递的路径. 公式 (4.1)中三个 "门"分别为输入门 遗忘门 和输出门 .这二个门的作用为
- 遗忘门 控制上一个时刻的内部状态 需要遗忘多少信息.
- 输入门 控制当前时刻的候选状态 有多少信息需要保存.
- 输出门 控制当前时刻的内部状态 有多少信息需要输出给外部状态
当 时,记忆单元将历史信息清空,并将候选状态向量 写入.但此时记忆单元 依然和上一时刻的历史信息相关.当 时,记忆单元将复制上一时刻的内容,不写入新的信息.
LSTM 网络中的“门”是一种“软”门,取值在 之间,表示以一定的比例允许信息通过.三个门的计算方式为: 其中 为 Logistic 函数,其输出区间为 为当前时刻的输入, 为上一时刻的外部状态.
下图给出了 LSTM 网络的循环单元结构,其计算过程为 ) 首先利用上一时刻的外部状态 和当前时刻的输入 ,计算出三个门,以及候选状态 ) 结合遗忘门 和输入门 来更新记忆单元 ) 结合输出门 ,将内部状态的信息传递给外部状态 .
通过 循环单元,整个网络可以建立较长距离的时序依赖关系.公式可以简洁地描述为 其中 为当前时刻的输入, 和 为网络参数.
记忆 循环神经网络中的隐状态 存储了历史信息,可以看作一种记忆 ( Mem- ory ).在简单循环网络中,隐状态每个时刻都会被重写,因此可以看作一种短期记忆 ( Short-Term Memory ).在神经网络中,长期记忆 ( Long-Term Memory ) 可以看作网络参数,隐含了从训练数据中学到的经验,其更新周期要远远慢于短期记忆.而在 LSTM 网络中,记忆单元 可以在某个时刻捕捉到某个关键信息,并有能力将此关键信息保存一定的时间间隔.记忆单元 中保存信息的生命周期要长于短期记忆 ,但又远远短于长期记忆,因此称为长短期记忆 ( Long Short-Term Memory ).
一般在深度网络参数学习时,参数初始化的值一般都比较小.但是在训练 LSTM 网络时,过小的值会使得遗忘门的值比较小.这意味着前一时刻的信息大部分都丢失了,这样网络很难捕捉到长距离的依赖信息.并且相邻时间间隔的梯度会非常小,这会导致梯度弥散问题.因此遗忘的参数初始值一般都设得比较大,其偏置向量 设为 1 或2.
5.2 LSTM网络的各种变体
目前主流的 LSTM 网络用三个门来动态地控制内部状态应该遗忘多少历史信息,输入多少新信息,以及输出多少信息.我们可以对门控机制进行改进并获得 LSTM 网络的不同变体.
无遗忘门的 LSTM 网络 [Hochreiter et al., 1997] 最早提出的 LSTM 网络是没有遗忘门的,其内部状态的更新为 如之前的分析,记忆单元 会不断增大.当输入序列的长度非常大时,记忆单元的容量会饱和,从而大大降低 LSTM 模型的性能. peephole 连接 另外一种变体是三个门不但依赖于输入 和上一时刻的隐状态 ,也依赖于上一个时刻的记忆单元 ,即 其中 和 为对角矩阵.
耦合输入门和遗忘门 LSTM 网络中的输入门和遗忘门有些互补关系,因此同时用两个门比较冗余.为了减少 LSTM 网络的计算复杂度,将这两门合并为一个 门.令 , 内部状态的更新方式为
5.3 门控循环单元网络
门控循环单元 ( Gated Recurrent Unit, GRU ) 网络 [Cho et al., 2014; Chung et al., 2014] 是一种比 LSTM 网络更加简单的循环神经网络.
GRU 网络引入门控机制来控制信息更新的方式.和 LSTM 不同,GRU 不引入额外的记忆单元, GRU 网络也是在公式 的基础上引入一个更新门 ( Up- date Gate ) 来控制当前状态需要从历史状态中保留多少信息 ( 不经过非线性变换 ) ,以及需要从候选状态中接受多少新信息,即 其中 为更新门 在 LSTM 网络中,输入门和遗忘门是互补关系,具有一定的咒余性.GRU 网络直接使用一个门来控制输入和遗忘之间的平衡.当 时,当前状态 和前一时刻的状态 之间为非线性函数关系;当 时, 和 之间为线性函数关系.
在 GRU 网络中,函数 的定义为 其中 表示当前时刻的候选状态, 为重置门 ( Reset Gate ) 用来控制候选状态 的计算是否依赖上一时刻的状态 . 当 时,候选状态 只和当前输入 相关,和历史 状态无关.当 时,候选状态 和当前输入 以及历史状态 相关,和简单循环网络一致.
综上,GRU网络的状态更新方式为 可以看出,当 时,GRU 网络退化为简单循环网络;若 时,当前状态 只和当前输入 相关,和历史状态 无关. 当 时,当前状态 等于上一时刻状态 ,和当前输入 无关.
下图给出了GRU 网络的循环单元结构.
六. 深层循环神经网络
如果将深度定义为网络中信息传递路径长度的话,循环神经网络可以看作既“深”又“浅”的网络.一方面来说,如果我们把循环网络按时间展开,长时间间隔的状态之间的路径很长,循环网络可以看作一个非常深的网络.从另一方面来说,如果同一时刻网络输入到输出之间的路径 ,这个网络是非常浅的.
因此,我们可以增加循环神经网络的深度从而增强循环神经网络的能力.增加循环神经网络的深度主要是增加同一时刻网络输入到输出之间的路径 ,比如增加隐状态到输出 ,以及输入到隐状态 之间的路径的 深度.
6.1 堆叠循环神经网络
一种常见的增加循环神经网络深度的做法是将多个循环网络堆叠起来,称为堆叠循环神经网络 ( Stacked Recurrent Neural Network, SRNN ) .一个堆叠的简单循环网络 ( Stacked SRN ) 也称为循环多层感知器 ( Recurrent Multi-Layer Perceptron, RMLP ) [Parlos et al., 1991].
下图给出了按时间展开的堆肯循环神经网络.第 层网络的输入是第 层网络的输出.我们定义 为在时刻 时第 层的隐状态 其中 和 为权重矩阵和偏置向量,.
6.2 双向循环神经网络
在有些任务中,一个时刻的输出不但和过去时刻的信息有关,也和后续时刻的信息有关.比如给定一个句子,其中一个词的词性由它的上下文决定,即包含左右两边的信息.因此,在这些任务中,我们可以增加一个按照时间的逆序来传递信息的网络层,来增强网络的能力.
双向循环神经网络 ( Bidirectional Recurrent Neural Network, Bi-RNN ) 由两层循环神经网络组成,它们的输入相同,只是信息传递的方向不同.
假设第 1 层按时间顺序,第 2 层按时间逆序,在时刻 时的隐状态定义为 和 , 则 其中 为向量拼接操作.
下图给出了按时间展开的双向循环神经网络.
七. 扩展到图结构
如果将循环神经网络按时间展开,每个时刻的隐状态 看作一个节点,那么这些节点构成一个链式结构,每个节点 都收到其父节点的消息(Message),更新自己的状态,并传递给其子节点.而链式结构是一种特殊的图结构,我们可以比较容易地将这种消息传递 ( Message Passing ) 的思想扩展到任意的图结 构上.
7.1 递归神经网络
递归神经网络 ( Recursive Neural Network, ) 是循环神经网络在有向无循环图上的扩展 [Pollack, 1990].递归神经网络的一般结构为树状的层次结构,如下图左所示.
以上图左中的结构为例,有三个隐藏层 和 , 其中 由两个输入层 和 计算得到, 由另外两个输入层 和 计算得到, 由两个隐藏层 和 计算得到.
对于一个节点 ,它可以接受来自父节点集合 中所有节点的消息,并更新自己的状态. 其中 表示集合 中所有节点状态的拼接, 是一个和节点位置无关的非线性函数,可以为一个单层的前馈神经网络.比如上图左所示的递归神经网络具体可以写为 其中 表示非线性激活函数, 和 是可学习的参数.同样,输出层 可以为一个分类器,比如 其中 为分类器, 和 为分类器的参数.当递归神经网络的结构退化为线性序列结构 (上图右) 时,递归神经网络就等价于简单循环网络.
递归神经网络主要用来建模自然语言句子的语义[Socher et al., 2011,2013].给定一个句子的语法结构 ( 一般为树状结构 ),可以使用递归神经网络来按照句法的组合关系来合成一个句子的语义.句子中每个短语成分又可以分成一些子 成分,即每个短语的语义都可以由它的子成分语义组合而来,并进而合成整句的语义.
同样,我们也可以用门控机制来改进递归神经网络中的长距离依赖问题,比如树结构的长短期记忆模型 ( Tree-Structured LSTM ) [Tai et al., 2015; Zhu et al., 2015 ] 就是将 LSTM 模型的思想应用到树结构的网络中,来实现更灵活的组合函数.
7.2 图神经网络
在实际应用中,很多数据是图结构的,比如知识图谱、社交网络、分子网络等.而前馈网络和反馈网络很难处理图结构的数据.
图神经网络 ( Graph Neural Network, GNN ) 是将消息传递的思想扩展到图结构数据上的神经网络.
对于一个任意的图结构 ,其中 表示节点集合, 表示边集合.每条边表示两个节点之间的依赖关系.节点之间的连接可以是有向的,也可以是无向的.图中每个节点 都用一组神经元来表示其状态 , 初始状态可以为节点 的输入特征 .每个节点可以收到来自相邻节点的消息,并更新自己的状态. 其中 表示节点 的邻居, 表示在第 时刻节点 收到的信息, 为 边 上的特征.
上式是一种同步的更新方式,所有的结构同时接受信息并更新自己的状态.而对于有向图来说,使用异步的更新方式会更有效率,比如循环神经网络或递归神经网络,在整个图更新 次后,可以通过一个读出函数 ( Readout Function ) 来得到整个网络的表示:
八. 总结和深入阅读
循环神经网络可以建模时间序列数据之间的相关性.和延时神经网络[Lang et al., 1990; Waibel et al., 1989] 以及有外部输入的非线性自回归模型[Leontaritis et al., 1985 ]相比,循环神经网络可以更方便地建模长时间间隔的相关性.
常用的循环神经网络的参数学习算法是 BPTT算法 [Werbos, 1990],其计算时间和空间要求会随时间线性增长.为了提高效率,当输入序列的长度比较大时,可以使用带截断 ( truncated ) 的 BPTT算法[Williams et al., 1990],只计算固定时间间隔内的梯度回传.
一个完全连接的循环神经网络有着强大的计算和表示能力,可以近似任何非线性动力系统以及图灵机,解决所有的可计算问题.然而由于梯度爆炸和梯度消失问题,简单循环网络存在长期依赖问题[Bengio et al., 1994; Hochreiter et al., 2001].为了解决这个问题,人们对循环神经网络进行了很多的改进,其中最有效的改进方式为引入门控机制,比如 LSTM 网络 [Gers et al., 2000; Hochreiter et al., 1997]和GRU网络[Chung et al., 2014].当然还有一些其他方法,比如时钟循环神经网络 ( Clockwork RNN ) [Koutnik et al., 2014]、乘法RNN[Sutskever et al., 2011; Wu et al., 2016] 以及引入注意力机制等.
LSTM 网络是目前为止最成功的循环神经网络模型,成功应用在很多领域,比如语音识别、机器翻译 [Sutskever et al., 2014] 语音模型以及文本生成. LSTM 网络通过引入线性连接来缓解长距离依赖问题.虽然 LSTM 网络取得了很大的 成功,其结构的合理性一直受到广泛关注.人们不断尝试对其进行改进来寻找最优结构,比如减少门的数量、提高并行能力等.关于 LSTM 网络的分析可以参考文献 [Greff et al., 2017; Jozefowicz et al., 2015; Karpathy et al., 2015].
LSTM 网络的线性连接以及门控机制是一种十分有效的避免梯度消失问题的方法.这种机制也可以用在深层的前馈网络中,比如残差网络 [He et al., 2016] 和高速网络[Srivastava et al., 2015] 都通过引入线性连接来训练非常深的卷积网络.对于循环神经网格,这种机制也可以用在非时间维度上,比如 Gird LSTM 网络 [Kalchbrenner et al., 2015] 、Depth Gated RNN[Chung et al., 2015 等.
此外,循环神经网络可以很容易地扩展到更广义的图结构数据上,称为图网络[Scarselli et al., 2009].递归神经网络是一种在有向无环图上的简单的图网络.图网络是目前新兴的研究方向,还没有比较成熟的网络模型.在不同的网络结构以及任务上,都有很多不同的具体实现方式.其中比较有名的图网络模型包括图卷积网络 ( Graph Convolutional Network, GCN ) [Kipf et al., 2016]、图注意力网络 ( Graph Attention Network, GAT ) [Veličković et al., 2017] 消息传递神经网络 ( Message Passing Neural Network, MPNN ) [Gilmer et al., 2017] 等.关于图网络的综述可以参考文献 [Battaglia et al., 2018].
生成对抗网络
产生于2014年,论文地址 Ian J. Goodfellow
生动的白话例子:莫烦教程
又李宏毅课上举的例子:类似生物的拟态,枯叶蝶进化中不断地去模仿叶子的形态以逃避天敌的捕食.
简而言之,有两个网络,生成网络(generative network,即枯叶蝶自身形态的进化)和对抗网络(adversarial network,即天敌对它的分辨),生成网络负责生成样本(可能依据一个分布得到的随机数来生成),而对抗网络负责判断这个样本是真实样本还是生成样本,两个网络共同训练.
一. 概率生成模型
概率生成模型(Probabilistic Generative Model),简称生成模型,指一系列用于随机生成可观测数据的模型.假设在一个连续或离散的高维空间 中,存在一个随机向量 服从一个未知的数据分布 .生成模型是根据一些可观测的样本 来学习一个参数化的模型 来近似未知分布 ,并可以用这个模型来生成一些样本,使得“生成”的样本和“真实”的样本尽可能地相似.生成模型通常包含两个基本功能:概率密度估计和生成样本(即采样).
1.1 密度估计
给定一组数据 ,假设它们都是独立地从相同的概率密度函数为 的未知分布中产生的.密度估计(Density Estimation)是根据数据集 来估计其概率密度函数 .
直接建模 比较困难.因此,我们通常通过引入隐变量 来简化模型,这样密度估计问题可以转换为估计变量 的两个局部条件概率 和 .一般为了简化模型,假设隐变量 的先验分布为标准高斯分布 .隐变量 的每一维之间都是独立的.在这个假设下,先验分布 中没有参数.因此,密度估计的重点是估计条件分布 .
1.2 生成样本
生成样本就是给定一个概率密度函数为 的分布,生成一些服从这个分布的样本,也称为采样.
在得到两个变量的局部条件概率 和 之后,我们就可以生成数据 ,具体过程可以分为两步进行:
- 根据隐变量的先验分布 进行采样,得到样本 .
- 根据条件分布 进行采样,得到样本 .
为了便于采样,通常 不能太过复杂.因此,另一种生成样本的思想是从一个简单分布 ( 比如标准正态分布 ) 中采集一个样本 , 并利用一个深度神经网络 使得 服从 这样,我们就可以避免密度估计问题,并有效降低生成样本的难度,这正是生成对抗网络的思想.
1.3 生成对抗网络
一种无监督学习.注意到生成网络与真实样本是未接触的,判别网络根据真实样本来更新参数,而生成网络根据判别网络来更新参数.
1.3.1 显式密度模型和隐式密度模型
一些深度生成模型,比如变分自编码器、深度信念网络等,都是显示地构建出样本的密度函数 ,并通过最大似然估计来求解参数,称为显式密度模型(Explicit Density Model).
如果只是希望有一个模型能生成符合数据分布 的样本,那么可以不显式地估计出数据分布的密度函数.假设在低维空间 中有一个简单容易采样的 分布 通常为标准多元正态分布 .我们用神经网络构建一个映射函数 ,称为生成网络.利用神经网络强大的拟合能力,使得 服从分布 .这种模型就称为隐式密度模型 ( Implicit Density Model ).所谓隐式模型就是指并不显式地建模 ,而是建模生成过程.
1.3.2 网络分解
生成对抗网络(Generative Adversarial Networks,GAN)[Goodfellowet al.,2014]是通过对抗训练的方式来使得生成网络产生的样本服从真实数据分布.在生成对抗网络中,有两个网络进行对抗训练.一个是判别网络,目标是尽量准确地判断一个样本是来自于真实数据还是由生成网络产生;另一个是生成网络,目标是尽量生成判别网络无法区分来源的样本.
1.3.2.1 判别网络
判别网络 ( Discriminator Network ) 的目标是区分出一个样本 是来自于真实分布 还是来自于生成模型 ,因此判别网络实际上是一个二分类的分类器.用标签 来表示样本来自真实分布, 表示样本来自生成模型,判别网络 的输出为 属于真实数据分布的概率,即
则样本来自生成模型的概率为
.给定一个样本 表示其来自于 还是 .判别网络的
目标函数为最小化交叉嫡,即
假设分布 是由分布 和分布 等比例混合而成,即
,则上式等价于
其中 和 分别是生成网络和判别网络的参数.
>回忆交叉熵定义: ,其中 为真实值, 为以 为参数, 为输入的模型输出的估计值.
1.3.2.2 生成网络
生成网络(Generator Network)的目标刚好和判别网络相反,即让判别网络将自己生成的样本判别为真实样本.
上面的这两个目标函数是等价的.但是在实际训练时,一般使用前者,因为其梯度性质更好.我们知道,函数 在 接近 时的梯度要比接近 时的梯度小很多,接近“饱和”区间.这样,当判别网络 以很高的概率认为生成网络 产生的样本是“假”样本时, 即 , 目标函数关于 的 梯度反而很小,从而不利于优化.
1.3.2.3 训练
和单目标的优化任务相比,生成对抗网络的两个网络的优化目标刚好相反.因此生成对抗网络的训练比较难,往往不太稳定.一般情况下,需要平衡两个网络的能力.对于判别网络来说,一开始的判别能力不能太强,否则难以提升生成网络的能力.但是,判别网络的判别能力也不能太弱,否则针对它训练的生成网络也不会太好.在训练时需要使用一些技巧,使得在每次迭代中,判别网络比生成网络的能力强一些,但又不能强太多.
生成对抗网络的训练流程如下图所示.每次迭代时,判别网络更新 𝐾 次而生成网络更新一次,即首先要保证判别网络足够强才能开始训练生成网络.在实践中 𝐾 是一个超参数,其取值一般取决于具体任务.
1.3.2.4 难点
生成对抗网络训练的难点在于,不像一般的loss function,我们只要看loss收不收敛就知道训练效果.GAN中最后判断网络无法辨别样本来自真实网络还是生成网络,可能是由于生成网络训练得很好,但也可能是由于判别网络训练得太差,反之亦然.
DRL
强化学习基础
一. 有模型数值迭代
1.1 度量空间与压缩映射
1.1.1 度量空间及其完备性
度量 ( metric,又称距离 ),是定义在集合上的二元函数.对于集合 ,其上的度量 ,需要满足
- 非负性:对任意的 , 有
- 同一性:对任意的 , 如果 , 则
- 对称性:对任意的 , 有
- 三角不等式: 对任意的 , 有 .
有序对 又称为度量空间 (metric space).我们来看一个度量空间的例子.考虑有限 Markov 决策过程状态函数 ,其所有可能的取值组成集合 ,定义 如下: 可以证明, 是 上的一个度量.(证明: 非负性、同一性、对称性是显然的.由于 有 可得三角不等式.)所以, 是一个度量空间.对于一个度量空间,如果 Cauchy 序列都收敛在该空间内,则称这个度量空间是完备的(complete).对于度量空间 也是完备的.(证明: 考虑其中任意 Cauchy 列 ,即对任意的正实数 ,存在正整数 使得任意的 ,均有 对于 ,所以 是 Cauchy 列.由实数集的完备性,可以知道 收敛于某个实数,记这个实数为 .所以,对于 ,存在正整数 ,对于任意 ,有 取 ,有 ,所以 收敛于 ,而 ,完备性得证).
1.1.2 压缩映射与Bellman算子
本节介绍压缩映射的定义,并证明 Bellman 期望算子和 Bellman 最优算子是度量空间 上的压缩映射.
对于一个度量空间 和其上的一个映射 ,如果存在某个实数 ,使得对于任意的 ,都有 则称映射 是压缩映射 ( contraction mapping, 或 Lipschitzian mapping).其中的实数 被称为 Lipschitz 常数.
※ Bellman期望方程
用状态价值函数表示状态价值函数: 用动作价值函数表示动作价值函数:
※ Bellman最优方程
用最优状态价值函数表示最优状态价值函数: 用最优动作价值函数表示最优动作价值函数:
这两个方程都有用状态价值表示状态价值的形式.根据这个形式,我们可以为度量空间 定义 Bellman 期望算子 和 Bellman 最优算子.
给定策略 的 Bellman 期望算子 Bellman 最优算子 : 下面我们就来证明,这两个算子都是压缩映射.
首先来看 期望算子 .由 的定义可知,对任意的 ,有 所以 考虑到 是任取的,所以有 当 时, 就是压缩映射.接下来看 Bellman 最优算子 .要证明 是压缩映射,需要用到下列不等式: 其中 和 是任意的以 为自变量的函数。(证明: 设 , 则 同理可证 ,于是不等式得证. 利用这个不等 式,对任意的 ,有 进而易知 ,所以 是压缩映射.
1.1.3 Banach不动点定理
对于度量空间 上的映射 ,如果 使得 ,则称 是映射 的不动点 (fix point).
例如,策略 的状态价值函数 满足 Bellman 期望方程,是 Bellman 期望算子 的不动点.最优状态价值 满足 Bellman 最优方程,是 Bellman 最优算子 的不动点.
完备度量空间上的压缩映射有非常重要的结论: Banach 不动点定理. Banach 不动点定理(Banach fixed-point theorem, 又称压缩映射定理, compressed mapping theorem) 的内容是: 是非空的完备度量空间, 是一个压缩映射,则映射 在 内有且仅有一个不动点 .更进一步,这个不动点可以通过下列方法求出:从 内的任意一个元素 开始,定义迭代序列 ,这个序列收敛,且极限为 .
证明:考虑任取的 及其确定的列 ,我们可以证明它是 Cauchy 序列.对于任意的 且 ,用距离的三角不等式和非负性可知, 再反复利用压缩映射可知,对于任意的正整数 有 ,代人得: 由于 ,所以上述不等式右端可以任意小,得证.
Banach 不动点定理给出了求完备度量空间中压缩映射不动点的方法:从任意的起点开始,不断迭代使用压缩映射,最终就能收敛到不动点.并且在证明的过程中,还给出了收敛速度,即迭代正比于 的速度收敛 其中 是迭代次数).在 节我们已经证明 是完备的度量空间,而 节又证明了 Bellman 期望算子和 最优算子是压缩映射,那么就可以用迭代的方法求 Bellman 期望算子和 Bellman 最优算子的不动点.于 期望算子的不动点就是策略价值,Bellman 最优算子的不动点就是最优价值,所以这就意味着我们可以用迭代的方法求得策略的价值或最优价值.在后面的小节中,就来具体看看求解的算法.
1.2 有模型策略迭代
本节介绍在给定动力系统 的情况下的策略评估、策略改进和策略迭代.策略评估、策略改进和策略迭代分别指以下操作.
- 策略评估(policy evaluation): 对于给定的策略 ,估计策略的价值, 包括动作价值和状态价值
- 策略改进(policy improvement): 对于给定的策略 ,在已知其价值函数的情况 找到一个更优的策略
- 策略迭代(policy iteration): 综合利用策略评估和策略改进,找到最优策略
1.2.1 策略评估
本节介绍如何用迭代方法评估给定策略的价值函数.如果能求得状态价值函数,那么就能很容易地求出动作价值函数.由于状态价值函数只有 个自变量,而动作价值函数有 个自变量,所以存储状态价值函数比较节约空间.
用迭代的方法评估给定策略的价值函数的算法如算法 1-1 所示.算法 1-1 一开始初始化状态价值函数 ,并在后续的迭代中用 期望方程的表达式更新一轮所有状态的状态价值函数.这样对所有状态价值函数的一次更新又称为一次扫描(sweep).在第 次扫描时,用 的值来更新 的值,最终得到一系列的 .
算法 1-1:有模型策略评估迭代算法
输入: 动力系统 , 策略 输出:状态价值函数 的估计值 参数: 控制迭代次数的参数(如误差容忍度 或最大迭代次数
- (初始化) 对于 ,将 初始化为任意值 (比如 0 ).如果有终止状态,将终止状态初始化为 0,即
- (迭代) 对于 ,迭代执行以下步骤 2.1 对于 , 逐一更新 ,其中. 2.2 如果满足迭代终止条件 (如对 均有 ,或达到最大迭代次数 ,则跳出循环
值得一提的是,算法 1-1 没必要为每次捉描都重新分配一套空间来存储.一种优化的方法是,设置奇数次迭代的存储空间和偶数次迭代的存储空间,一开始初始化偶数次存储空间,当 是奇数时,用偶数次存储空间来更新奇数次存储空间; 当 是偶数时, 用奇数次存储空间来更新偶数次存储空间.这样,一共只需要两套存储空间就可以完成算法.
1.2.2 策略改进
对于给定的策略 ,如果得到该策略的价值函数,则可以用策略改进定理得到一个改进的策略.
策略改进定理的内容如下:对于策略 和 ,如果 则 ,即 在此基础上,如果存在状态使得第一式的不等号是严格小于号,那么就存在状态使得第二式中的不等号也是严格小于号.
证明: 考虑到第一个不等式等价于 其中的期望是针对用策略 生成的轨迹中,选取 的那些轨迹而言的.进而有 考虑到 所以 进而有 严格不等号的证明类似.
对于一个确定性策略 ,如果存在着 ,使得 ,那么我们可以构造一个新的确定策略 ,它在状态 做动作 ,而在除状态 以外的状态的动作都和策略一样.可以验证,策略 和 满足策略改进定理的条件.这样,我们就得到了一个比策略 更好的策略 .这样的策略更新算法可以用算法 1-2 来表示.
算法 1-2:有模型策略改进算法
输入: 动力系统 ,策略 及其状态价值函数 输出:改进的策略 ,或策略 已经达到最优的标志
- 对于每个状态 ,执行以下步骤: 1.1 为每个动作 ,求得动作价值函数 1.2 找到使得 最大的动作 ,即
- 如果新策略 和旧策略 相同,则说明旧策略巳是最优; 否则,输出改进的新策略
值得一提的是,在算法 1-2 中,旧策略 和新策略 只在某些状态上有不同的动作值, 新策略 可以很方便地在旧策略 的基础上修改得到.所以,如果在后续不需要使用旧策略的情况下,可以不为新策略分配空间.
1.2.3 策略迭代
策略迭代是一种综合利用策略评估和策略改进求解最优策略的迭代方法.
见算法 1-3 ,策略迭代从一个任意的确定性策略 开始,交替进行策略评估和策略改进.这里的策略改进是严格的策略改进,即改进后的策略和改进前的策略是不同的 对于状态空间和动作空间均有限的 Markov 决策过程,其可能的确定性策略数是有限的.由于确定性策略总数是有限的,所以在迭代过程中得到的策略序列 一定能收敛,使得到某个 ,有 (即对任意的 均有 .由于在 的情况下, ,进而 ,满足 Bellman 最优方程.因此, 就是最优策略.这样就证明了策略迭代能够收敛到最优策略.
算法 1-3:有模型策略迭代
输入: 动力系统 输出:最优策略
- (初始化)将策略 初始化为一个任意的确定性策略.
- (迭代) 对于 , 执行以下步骤 2.1 (策略评估)使用策略评估算法,计算策略 的状态价值函数 2.2 (策略更新)利用状态价值函数 改进确定性策略 ,得到改进的确定性策略 .如果 .(即对任意的 均有 ),则迭代完成,返回策略 为最终的最优策略.
策略迭代也可以通过重复利用空间来节约空间.为了节约空间,在各次迭代中用相同的空间 来存储状态价值函数,用空间 来存储确定性策略.
1.3 有模型价值迭代
价值迭代是一种利用迭代求解最优价值函数进而求解最优策略的方法.在 节介绍的策略评估中,迭代算法利用 Bellman 期望方程迭代求解给定策略的价值函数.与之相对,本节将利用 Bellman 最优方程迭代求解最优策略的价值函数,并进而求得最优策略.
与策略评估的情形类似,价值迭代算法有参数来控制迭代的终止条件,可以是误差容忍度 或是最大迭代次数 .
算法 1-4 给出了一个价值迭代算法.这个价值迭代算法中先初始化状态价值函数,然后用 Bellman 最优方程来更新状态价值函数.根据第 1.1 节的证明,只要迭代次数足够多,最终会收敘到最优价值函数.得到最优价值函数后,就能很轻易地给出确定性的最优策略.
算法 1-4: 有模型价值迭代算法
输入:动力系统 愉出:最优策略估计 参数:策略评估需要的参数
- (初始化) 任意值, .如果有终止状态,
- (迭代) 对于 ,执行以下步骤 2.1 对于 ,逐一更新 2.2 如果满足误差容忍度 (即对于 均有 或达到最大迭代次数 (即 ,则跳出循环
- (策略) 根据价值函数输出确定性策略 ,使得
与策略评估的迭代求解类似,价值迭代也可以在存储状态价值函数时重复使用空间.算法 1-5 给出了重复使用空间以节约空间的版本.
算法 1-5:有模型价值迭代 (节约空间的版本 )
输入: 动力系统 输出:最优策略 参数:策略评估需要的参数
- (初始化) 任意值, .如果有终止状态,
- (迭代) 对于 ,执行以下步骤
2.1 对于使用误差容忍度的情况,初始化本次迭代观测到的最大误差
2.2 对于 执行以下操作:
- 计算新状态价值
- 对于使用误差容忍度的情况,更新本次迭代观测到的最大误差
- 更新状态价值函数 2.3 如果满足误差容忍度(即 或达到最大迭代次数 即 ,则跳出循环
- (策略) 根据价值函数输出确定性策略:
二. 回合更新价值迭代
本章开始介绍无模型的机器学习算法.无模型的机器学习算法在没有环境的数学描述的情况下,只依靠经验(例如轨迹的样本) 学习出给定策略的价值函数和最优策略.在现实生活中,为环境建立精确的数学模型往往非常困难.因此,无模型的强化学习是强化学习的主要形式.
根据价值函数的更新时机,强化学习可以分为回合更新算法和时序差分更新算法这两类.回合更新算法只能用于回合制任务,它在每个回合结束后更新价值函数.本章将介绍回合更新算法,包括同策回合更新算法和异策回合更新算法.
2.1 同策回合更新
本节介绍同策回合更新算法.与有模型迭代更新的情况类似,我们也是先学习同策策略评估,再学习最优策略求解.
2.1.1 同策回合更新策略评估
本节考虑用回合更新的方法学习给定策略的价值函数.我们知道,状态价值和动作价值分别是在给定状态和状态动作对的情况下回报的期望值.回合更新策略评估的基本思路使用 Monte Carlo 方法来估计这个期望值.具体而言,在许多轨迹样本中,如果某个状态(或状态动作对) 出现了 次,其对应的回报值分别为 ,那么可以估计其状态价 (或动作价值) 为 .
无模型策略评估算法有评估状态价值函数和评估动作价值函数两种版本.在有模型情况下,状态价值和动作价值可以互相表示;但是在无模型的情况下,状态价值和动作价值并不能互相表示.我们已经知道,任意策略的价值函数满足 Bellman 期望方程.借助于动力 (某个状态转移分布)的表达式,我们可以用状态价值函数表示动作价值函数;借助于策略 的表达式,我们可以用动作价值函数表示状态价值函数.所以,对于无模型的策略评估, 的表达式未知,只能用动作价值表示状态价值,而不能用状态价值表示动作价值.另外,由于策略改进可以仅由动作价值函数确定,因此在学习问题中,动作价值函数往往更加重要.
在同一个回合中,多个步骤可能会到达同一个状态 (或状态动作对),即同一状态(或状态动作对)可能会被多次访问.对于不同次的访问,计算得到的回报样本值很可能不相 同.如果采用回合内全部的回报样本值更新价值函数, 则称为每次访问回合更新(every visit Monte Carlo update); 如果每个回合只采用第一次访问的回报样本更新价值函数,则称为首次访问回合更新( first visit Monte Carlo update).每次访问和首次访问在学习过程中的中间值并不相同,但是它们都能收敛到真实的价值函数.
首先来看每次访问回合更新策略评估算法.算法 2-1 给出了每次访问更新求动作价值的算法.我们来逐步看一下算法 2-1 .算法 2-1 首先对动作价值 进行初始化. 可以初始化为任意的值,因为在第一次更新后 的值就和初始化的值没有关系,所以将 初始化为什么数无关紧要.接着,算法 2-1 进行回合更新.与有模型迭代更新的情形类似,这里可以用参数来控制回合更新的回合数.例如,可以使用最大回合数 或者精度指标 .在生成好轨迹后,算法 2-1 采用逆序的方式更新 .这里采用逆序是为了使用 这一关系来更新 值,以减小计算复杂度.
算法 2-1:每次访问回合更新评估策略的动作价值
输入:环境(无数学描述),策略 输出:动作价值函数
- (初始化) 初始化动作价值估计 任意值, ,若更新价值需要使用计数 器,则初始化计数器 。
- (回合更新) 对于每个回合执行以下操作 2.1 (采样) 用策略 生成轨迹 2.2 (初始化回报) 2.3 (逐步更新)对 执行以下步骤 1. (更新回报) 2. (更新动作价值)更新 以减小 (如 , .
算法 2-1 在更新动作价值时,可以采用增量法来实现 Monte Carlo 方法.增量法的原理如下:如前 次观察到的回报样本是 ,则前 次价值函数的估计值为 ;如果第 次的回报样本是 ,则前 次价值函数的估计值为 可以证明, .所以,只要知道出现的次数 ,就可以用新的观测 把旧的平均值 更新为新的平均值 .因此,增量法不仅需要记录当前的价值估计 还需要记录状态动作对出现的次数 .在算法 2-1 中,状态动作对 的出现次数记录在 里,每次更新时将计数值加 1,再更新平均值 ,这样就实现了增量法.
求得动作价值后,可以用 Bellman 期望方程求得状态价值.状态价值也可以直接用回合更新的方法得到.算法 2-2 给出了每次访问回合更新评估策略的状态价值的算法.它与算法 2-1 的区别在于将 替换为了 ,计数也相应做了修改.
算法 4-2: 每次访问回合更新评估策略的状态价值
输入: 环境(无数学描述),策略 输出:状态价值函数
- (初始化)初始化状态价值估计 任意值, ,若更新价值时需要使用计数器则更新初始化计数器
- (回合更新)对于每个回合执行以下操作
2.1 (采样)用策略 生成轨迹
2.2 (初始化回报) 。
2.3 (逐步更新) 对 执行以下步骤 :
- (更新回报)
- (更新状态价值)更新 以减小 如 ,
首次访问回合更新策略评估是比每次访问回合更新策略评估更为历史悠久、更为全面研究的算法.算法 2-3 给出了首次访问回合更新求动作价值的算法.这个算法和算法 2-1 的区别在于,在每次得到轨迹样本后,先找出各状态分别在哪些步骤被首次访问.在后续的更新过程中,只在那些首次访问的步骤更新价值函数的估计值.
算法 2-3:首次访问回合更新评估策略的动作价值
输入: 环境(无数学描述),策略 输出:动作价值函数 .
- (初始化)初始化动作价值估计 任意值, ,若更新动作价值时需要计数器,则初始化计数器 .
- (回合更新)对于每个回合执行以下操作
2.1 (采样)用策略 生成轨迹
2.2 (初始化回报)
2.3 (初始化首次出现的步骤数)
2.4 (统计首次出现的步骤数)对于 ,执行以下步骤:如果 ,则
2.5 (逐步更新)对 ,执行以下步骤:
- (更新回报)
- (首次出现则更新)如果 ,则更新 以减小 如.
与每次访问的情形类似,首次访问也可以直接估计状态价值,见算法 2-4 .当然也可借助 Bellman 期望方程用动作价值求得状态价值.
算法 2-4:首次访问回合更新评估策略的状态价值
输入:环境(无数学描述),策略 输出:状态价值函数
- (初始化)初始化状态价值估计 任意值, ,若更新价值时需要使用计数器,更新初始化计数器
- (回合更新)对于每个回合执行以下操作
2.1 (采样)用策略 生成轨迹
2.2 (初始化回报)
2.3 (初始化首次出现的步骤数)
2.4 (统计首次出现的步骤数)对于 ,执行以下步骤:如果 ,则
2.5 (逐步更新)对 ,执行以下步骤 :
- (更新回报)
- (首次出现则更新)如果 ,则更新 以减小 如 ,.
TODO:起始探索与柔性策略
2.2 异策回合更新
本节考虑异策回合更新.异策算法允许生成轨迹的策略和正在被评估或被优化的策路不是同一策略.我们将引人异策算法中一个非常重要的概念——重要性采样,并用其进行 策略评估和求解最优策略.
2.2.1 重要性采样
在统计学上,重要性采样(importance sampling)是一种用一个分布生成的样本来估计另一个分布的统计量的方法.在异策学习中,将要学习的策略 称为目标策略(target policy),将用来生成行为的另一策略 称为行为策略(behavior policy ).重要性采样可以用行为策略生成的轨迹样本生成目标策略的统计量.
现在考虑从 开始的轨迹 .在给定 的条件下,采用 策略 和策略 生成这个轨迹的概率分别为: 我们把这两个概率的比值定义为重要性采样比率 (importance sample ratio): 这个比率只与轨迹和策略有关,而与动力无关.为了让这个比率对不同的轨迹总是有意义,我们需要使得任何满足 的 ,均有 这样的关系可以记为.
对于给定状态动作对 的条件概率也有类似的分析.在给定 的条件下,采用策略 和策略 生成这个轨迹的概率分别为: 其概率的比值为 回合更新总是使用 Monte Carlo 估计价值函数的值.同策回合更新得到 个回报 后,用平均值 来作为价值函数的估计.这样的方法实际上默认了这 个回报是等概率出现的.类似的是,异策回合更新用行为策略 得到 个回报 ,这个回报值对于行为策略 是等概率出现的.但是这 个回报值对于目标策略 不是等概率出现的.对于目标策略 而言,这 个回报值出现的概率正是各轨迹的重要性采样比率.这样,我们可以用加权平均来完成 Monte Carlo 估计.具体而言,若 是回报样本 对应的权重(即轨迹的重要性采样比率),可以有以下两种加权方法.
- 加权重要性采样(weighted importance sampling ), 即
- 普通重要性采样(ordinary importance sampling), 即
这两种方法的区别在于分母部分.对于加权重要性采样,如果某个权重 ,那么它不会让对应的 参与平均,并不影响整体的平均值;对于普通重要性采样,如果某个权重 ,那么它会让 0 参与平均,使得平均值变小.无论是加权重要性采样还是普通重要性采样,当回报样本数增加时,仍然可以用增量法将旧的加权平均值更新为新的加权平均值.对于加权重要性采样,需要将计数值替换为权重的和,以 的形式作更新.对于普通重要性采样而言,实际上就是对 加以平均,与直接没有加权情况下对 加以平均没有本质区别.它的更新形式为 :
2.2.2 异策回合更新策略评估
基于 2.2.1 节给出的重要性采样,算法 2-7 给出了每次访问加权重要性采样回合更新策略评估算法.这个算法在初始化环节初始化了权重和 与动作价值 ,然后进行回合更新.回合更新需要借助行为策略 .行为策略 可以每个回合都单独设计,也可以为整个算法设计一个行为策略,而在所有回合都使用同一个行为策略.用行为策略生成轨迹样本后,逆序更新回报、价值函数和权重值.一开始权重值 设为 1 ,以后会越来越小.如果某次权重值变为 0 (这往往是因为 ,那么以 后的权重值就都为 0 ,再循环下去没有意义.所以这里设计了一个检查机制.事实上,这个检查机制保证了在更新 时权重和 是必需的.如果没有检查机制,则可能在更新 时,更新前和更新后的 值都是 0 ,进而在更新 时出现除零错误.加这个检查机制避免了这样的错误.
算法 2-7: 每次访问加权重要性采样异策回合更新评估策略的动作价值
- (初始化)初始化动作价值估计 任意值, ,如果需要使用权重和,则初始化权重和
- (回合更新)对每个回合执行以下操作
2.1 (行为策略)指定行为策略 ,使得
2.2 (采样)用策略 生成轨迹:
2.3 (初始化回报和权重)
2.4 对于 执行以下操作:
- (更新回报)
- (更新价值)更新 以减小 如 ,
- (更新权重)
- (提前终止)如果 ,则结東步骤 2.4 的循环
在算法 2-7 的基础上略作修改,可以得到首次访问的算法、普通重要性采样的算法和估计状态价值的算法,此处略过.
2.2.3 异策回合更新最优策略求解
接下来介绍最优策略的求解.算法 2-8 给出了每次访问加权重要性采样异策回合最优策略求解算法.它和其他最优策略求解算法一样,都是在策略估计算法的基础上加上策略改进得来的.算法 2-8 的迭代过程中,始终让 是一个确定性策略.所以,在回合更新的过程中,任选一个策略 都满足 .这个柔性策略可以每个回合都分别选取,也可以整个程序共用一个.由于采用了确定性的策略,则对于每个状态 都有一个 使得 ,而其他 .算法 2-8 利用这一性质来更新权重并判断权重是否为 0 .如果 ,则意味着 ,更新后的权重为 0 ,需要退出循环以避免除零错误;若 ,则意味着 ,所以权重更新语句 就可以简化为 .
算法 2-8: 每次访问加权重要性采样异策回合更新最优策略求解
- (初始化)初始化动作价值估计 任意值, ,如果需要使用权重和,初始化权重和
- (回合更新)对每个回合执行以下操作
2.1 (柔性策略)指定 为任意柔性策略
2.2 (采样)用策略 生成轨迹:
2.3 (初始化回报和权重)
2.4 对
- (更新回报)
- (更新价值)更新 以减小 如 ,
- (策略更新)
- (提前终止)若 则退出步骤 2.4
- (更新权重)
算法 2-8 也可以修改得到首次访问的算法和普通重要性采样的算法,此处略过.
三. 时序差分价值迭代
本章介绍另外一种学习方法一时序差分更新.时序差分更新和回合更新都是直接采用经验数据进行学习,而不需要环境模型.时序差分更新与回合更新的区别在于,时序差 分更新汲取了动态规划方法中"自益"的思想,用现有的价值估计值来更新价值估计,不需要等到回合结束也可以更新价值估计.所以,时序差分更新既可以用于回合制任务,也可以用于连续性任务.
本章将介绍时序差分更新方法,包括同策时序差分更新方法和异策时序差分更新方法, 每种方法都先介绍简单的单步更新,再介绍多步更新.最后,本章会涉及基于资格迹的学习算法.
3.1 同策时序差分更新
本节考虑无模型同策时序差分更新.与无模型回合更新的情况相同,在无模型的情况下动作价值比状态价值更为重要,因为动作价值能够决定策略和状态价值,但是状态价值得不到动作价值.
本节考虑无模型同策时序差分更新。与无模型回合更新的情况相同,在无模型的情况下动作价值比状态价值更为重要,因为动作价值能够决定策略和状态价值,但是状态价值得不到动作价值。 从给定策略 的情况下动作价值的定义出发,我们可以得到下式: 在上一章的回合更新学习中,我们依据 ,用 Monte Carlo 方法来估计价值函数.为了得到回报样本,我们要从状态动作对 出发一直采样到回合结束.单步时序差分更新将依据 ,只需要采样一步,进而用 ,来估计回报样本的值.为了与由奖励直接计算得到的无偏回报样本 进行区别,本书用字母 表示使用自益得到的有偏回报样本.
基于以上分析,我们可以定义时序差分目标.时序差分目标可以针对动作价值定义,也可以针对状态价值定义.对于动作价值,其单步时序差分目标定义为 其中 的上标 表示是对动作价值定义的,下标 表示用 的估计值来估计 .如果 是终止状态,默认有 .这样的时序差分目标可以进一步扩展到多步的情况. n 步时序差分目标 定义为 在不强调步数的情况下, 可以简记为 或 .对于回合制任务,如果回合的步数 ,则我们可以强制让 这样,上述时序差分目标的定义式依然成立,实际上 达到了 的效果。本书后文都做这样的假设。类似的是,对于 状态价值,定义 步时序差分目标为 它也可以简记为 或 .
3.1.1 时序差分更新策略评估
本节考虑利用时序差分目标来评估给定策略的价值函数.回顾在同策回合更新策略评估中,我们用形如 的增量更新来学习动作价值函数,试图减小 .在这个式子中, 是回报样本.在时序差分中,这个量就对应着 .因此,只需在回合更新策略评估算法的基础上,将这个增量更新式中的回报 替换为时序差分目标 ,就可以得到时序差分策略评估算法了.
时序差分目标既可以是单步时序差分目标,也可以是多步时序差分目标.我们先来看单步时序差分目标.
算法 3-1 给出了用单步时序差分更新评估策略的动作价值的算法.这个算法有一个 ,它是一个正实数,表示学习率.在上一章的回合更新中,这个学习率往往是 ,它和状态动作对有关,并且不断减小.在时序差分更新中,也可以采用这样不断减小的学习率.不过,考虑到在时序差分算法执行的过程中价值函数会越来越准确,进而基于价值函数估计得到的价值函数也会越来越准确,因此估计值的权重可以越来越大.所以,算法 3-1 采用了一个固定的学习率 .这个学习率一般在 .当然,学习率也可以不是常数.在有些问题中,让学习率巧妙地变化能得到更好的效果.引入学习率 后,更新式可以表示为:
算法 3-1: 单步时序差分更新评估策略的动作价值
输入: 环境(无数学描述)、策略 输出:动作价值函数 参数:优化器(隐含学习率 ),折扣因子 ,控制回合数和回合内步数的参数.
- (初始化) 任意值, .如果有终止状态,令 .
- (时序差分更新)对每个回合执行以下操作
2.1 (初始化状态动作对)选择状态 ,再根据输入策略 确定动作
2.2 如果回合未结東(例如未达到最大步数、S不是终止状态), 执行以下操作:
- (采样)执行动作 ,观测得到奖励 和新状态
- 用输入策略 确定动作
- (计算回报的估计值)
- (更新价值)更新 以减小 如
- .
在具体的更新过程中,除了学习率 和折扣因子 外,还有控制回合数和每个回合步数的参数.我们知道,时序差分更新不仅可以用于回合制任务,也可以用于非回合制任务.对于非回合制任务,我们可以自行将某些时段抽出来当作多个回合,也可以不划分回合当作只有一个回合进行更新.类似地,算法 3-2 给出了用单步时序差分方法评估策略状态价值的算法.
算法 3-2 单步时序差分更新评估策略的状态价值
输入:环境(无数学描述)、策略 输出:状态价值函数 参数:优化器(隐含学习率 ),折扣因子 ,控制回合数和回合内步数的参数
- (初始化) 任意值, .如果有终止状态,
- (时序差分更新)对每个回合执行以下操作
2.1 (初始化状态)选择状态
2.2 如果回合未结東(例如未达到最大步数、S不是终止状态),执行以下操作:
- 根据输入策略 确定动作
- (采样)执行动作 ,观测得到奖励 和新状态
- (计算回报的估计值)
- (更新价值)更新 以减小 如
在无模型的情况下,用回合更新和时序差分更新来评估策略都能渐近得到真实的价值函数.它们各有优劣.目前并没有证明某种方法就比另外一种方法更好.根据经验,学习 率为常数的时序差分更新常常比学习率为常数的回合更新更快收敛.不过时序差分更新对环境的 Markov 性要求更高.
我们通过一个例子来比较回合更新和时序差分更新.考虑某个 Markov 奖励过程,我们得到了它的 5 个轨迹样本如下(只显示状态和奖励): 使用回合更新得到的状态价值估计值为 ,而使用时序差分更新得到的状态价值估计值为 .这两种方法对 的估计是一样的,但是对于 的估计有明显不同:回合更新只考虑其中两个含有 的轨迹样本,用这两个轨迹样本回报来估计状态价值;时序差分更新认为状态 下一步肯定会到达状态 ,所以可以利用全部轨迹样本来估计 ,进而由 推出 .试想,如果这个环境真的是 Markov 决策过程,并且我们正确地识别出了状态空间 ,那么时序差分更新方法可以用更多的轨迹样本来帮助估计 的状态价值,这样可以更好地利用现有的样本得到更精确的估计.但是,如果这个环境其实并不是 Markov 决策过程,或是 并不是其真正的状态空间.那么也有可能 之后获得的奖励值其实和这个轨迹是否到达过 有关,例如如果达到过 则奖励总是为 0 .这种情况下,回合更新能够不受到这一错误的影响,只采用正确的信息,从而不受无关信息的干扰,得到正确的估计.这个例子比较了回合更新和时序差分更新的部分利弊。
接下来看如何用多步时序差分目标来评估策略价值.算法 3-3 和算法 3-4 分别给出了用多步时序差分评估动作价值和状态价值的算法.实际实现时,可以让 和 共享同一存储空间,这样只需要 份存储空间.
算法 3-3 步时序差分更新评估策略的动作价值
输入:环境(无数学描述)、策略 输出:动作价值估计 参数:步数 ,优化器 (隐含学习率 ),折扣因子 ,控制回合数和回合内步数的参数
- (初始化) 任意值 .如果有终止状态,令
- (时序差分更新)对每个回合执行以下操作
2.1 (生成 n 步)用策略 生成轨迹 (若遇到终止状态,则令后续奖励均为0,状态均为
2.2 对于 依次执行以下操作,直到 :
- 若 ,则根据 决定动作
- (更新时序差分目标)
- (更新价值)更新 以减小
- 若 ,则执行 ,得到奖励 和下一状态 ;若 ,
算法 3-4 步时序差分更新评估策略的状态价值
输入: 环境(无数学描述)、策略 输出:状态价值估计 参数:步数 ,优化器(隐含学习率 ),折扣因子 ,控制回合数和回合内步数的参数
- (初始化) 任意值, 如果有终止状态,令
- (时序差分更新) 对每个回合执行以下操作
2.1(生成 步)用策略 生成轨迹 (若遇到终止状态,则令后续奖励均为 0 ,状态均为 ).
2.2 对于 依次执行以下操作,直到 :
- (更新时序差分目标 )
- (更新价值)更新 以减小
- 若 ,则根据 决定动作 并执行,得到奖励 和下一状态 ;若 ,令 .
3.1.2 SARSA算法
本节我们采用同策时序差分更新来求解最优策略.首先我们来看 “状态 / 动作 / 奖励 状态 / 动作”(State-Action-Reward-State-Action, SARSA)算法.这个算法得名于更新涉及的随机变量 .该算法利用 得到单步时序差分目标 ,进而更新 .该算法的更新式为: 其中 是学习率.算法 3-5 给出了用 SARSA 算法求解最优策略的算法.SARSA 算法就是在单步动作价值估计的算法的基础上,在更新价值估计后更新策略.在算法 3-5 中,每当最优动作价值函数的估计 更新时,就进行策略改进,修改最优策略的估计 .策略的提升方法可以采用 贪心算法,使得 总是柔性策略.更新结束后,就得到最优动作价值估计和最优策略估计.
算法 3-5 SARSA 算法求解最优策略(显式更新策略)
与 Q-learning 区别在于,SARSA 每次算得的 在下一步会使用,也就是这一步通过 Q 函数预测得到的下一步采取的动作 也是在下一步更新时真正使用的,而 Q-learning 在下一步更新时的 是重新算的(用这一步更新后的 Q 函数算得)
输入:环境(无数学描述) 输出:最优策略估计 和最优动作价值估计 参数:优化器(隐含学习率 ),折扣因子 ,策略改进的参数(如 ),其他控制回合数和回合步数的参数
- (初始化) 任意值, .如果有终止状态,令 .用动作价值 确定策略 如使用 贪心策略
- (时序差分更新)对每个回合执行以下操作
2.1 (初始化状态动作对)选择状态 ,再用策略 确定动作
2.2 如果回合未结東(比如未达到最大步数、 不是终止状态),执行以下操作:
- (采样)执行动作 ,观测得到奖励 和新状态
- 用策略 确定动作
- (计算回报的估计值)
- (更新价值)更新 以减小 如
- (策略改进)根据 修改 如 贪心策略
其实,在同策迭代的过程中,最优策略也可以不显式存储.另外多步 SARSA 此处省略.
3.1.3 期望SARSA算法
SARSA 算法有一种变化一一期望 SARSA 算法(Expected SARSA).期望 SARSA算法与 SARSA 算法的不同之处在于,它在估计 时,不使用基于动作价值的时序差分目标 ,而使用基于状态价值的时序差分目标 .利用 Bellman 方程,这样的目标又可以表示为 与 SARSA 算法相比,期望 SARSA 需要计算 ,所以计算量比 SARSA 大.但是,这样的期望运算减小了 SARSA 算法中出现的个别不恰当决策.这样,可以避免在更新后期极个别不当决策对最终效果带来不好的影响.因此,期望 SARSA 常常有比 SARSA 更大的学习率.在很多情况下,期望 SARSA 的效果会比 SARSA 稍微好一些.
算法 3-6 给出了期望 SARSA 求解最优策略的算法,它可以视作在单步时序差分状态价值估计算法上修改得到的.期望 SARSA 对回合数和回合内步数的控制方法等都和 SARSA 相同,但是由于期望 SARSA 在更新 时不需要 ,所以其循环结构有所简化.算法中让 保持为 柔性策略.如果 很小,那么这个 柔性策略就很接近于确定性策略,则期望 SARSA 计算的 就很接近于 .
算法 3-6 期望 SARSA 求解最优策略
- (初始化) 任意值, .如果有终止状态,令 .用动作价值 确定策略 (如使用 贪心策略)
- (时序差分更新)对每个回合执行以下操作
2.1 (初始化状态)选择状态
2.2 如果回合未结東(比如未达到最大步数、S不是终止状态),执行以下操作:
- 用动作价值 确定的策略(如 贪心策略)确定动作
- (采样)执行动作 ,观测得到奖励 和新状态
- (用期望计算回报的估计值)
- (更新价值)更新 以减小 (如
3.2 异策时序差分更新
本节介绍异策时序差分更新.异策时序差分更新是比同策差分更新更加流行的算法.特别是 Q 学习算法,已经成为最重要的基础算法之一.
3.2.1 基于重要性采样的异策算法
时序差分策略评估也可以与重要性采样结合,进行异策的策略评估和最优策略求解.对于 步时序差分评估策略的动作价值和 SARSA 算法,其时序差分目标 依赖于轨迹 .在给定 的情况下,采用策略 和另外的行为策略 生成这个轨迹的概率分别为: 它们的比值就是重要性采样比率: 也就是说,通过行为策略 拿到的估计,在原策略 出现的概率是在策略 中出现概率的 倍.所以,在学习过程中,这样的时序差分目标的权重为 .将这个权重整合到时序差分策略评估动作价值算法或 SARSA算法中,就可以得到它们的重要性采样的版本.算法 3-7 给出了多步时序差分的版本,单步版本请自行整理.
算法 3-7: 步时序差分策略评估动作价值或 SARSA 算法
输入:环境(无数学描述)、策略 输出:动作价值函数 ,若是最优策略控制则还要输出策略 参数:步数 ,优化器 (隐含学习率 ),折扣因子 ,控制回合数和回合内步数的参数
- (初始化) 任意值, .如果有终止状态,令 .若是最优策略控制,还应该用 决定 (如 贪心策略)
- (时序差分更新)对每个回合执行以下操作
2.1 (行为策略)指定行为策略 ,使得
2.2 (生成 步)用策略 生成轨迹 (若遇到终止状态,则令后续奖励均为 0 , 状态均为 )
2.3 对于 依次执行以下操作,直到 :
- 若 ,则根据 决定动作
- (更新时序差分目标 )
- (计算重要性采样比率 )
- (更新价值)更新 以减小
- (更新策略)如果是最优策略求解算法,需要根据 修改
- 若 ,则执行 ,得到奖励 和下一状态 ;若 ,则令 .
我们可以用类似的方法将重要性采样运用于时序差分状态价值估计和期望 SARSA 算法中.具体而言,考虑从 开始的 步轨迹 .在给定 的条件下,采用策略 和策略 生成这个轨迹的概率分别为: 它们的比值就是时序差分状态评估和期望 SARSA 算法用到的重要性采样比率:
3.2.2 Q学习
3.1.3 节的期望 SARSA 算法将时序差分目标从 SARSA 算法的 改为 ,从而避免了偶尔出现的不当行为给整体结果带来的负面影响. Q 学习则是从改进后策略的行为出发,将时序差分目标改为 Q 学习算法认为,在根据 估计 时,与其使用 或 ,还不如使用根据 改进后的策略来更新,毕竟这样可以更接近最优价值.因此 Q 学习的更新式不是基于当前的策略,而是基于另外一个并不一定要使用的确定性策略来更新动作价值.从这个意义上看,Q 学习是一个异策算法.算法 3-8 给出了 Q 学习算法.Q 学习算法和期望 SARSA 有完全相同的程序结构,只是在更新最优动作价值的估计 时使用了不同的方法来计算目标.
算法 3-8: Q 学习算法求解最优策略
- (初始化) 任意值, .如果有终止状态,令
- (时序差分更新)对每个回合执行以下操作
2.1 (初始化状态)选择状态
2.2 如果回合未结東(例如未达到最大步数、S不是终止状态),执行以下操作:
- 用动作价值估计 确定的策略决定动作 如 贪心策略
- (采样)执行动作 ,观测得到奖励 和新状态
- (用改进后的策略计算回报的估计值)
- (更新价值和策略)更新 以减小 如
当然,Q学习也有多步的版本,其目标为: 具体省略
3.2.3 双重Q学习
上一节介绍的 Q 学习用 来更新动作价值,会导致“最大化偏差 (maximization bias),使得估计的动作价值偏大.
我们来看一个最大化偏差的例子.下所示的回合制任务中,Markov决策过程的状态空间为 ,回合开始时总是处在 状态,可以选择的动作空间 .如果选择动作 ,则可以到达状态 ,该步奖励为 0 ;如果选择动作 ,则可以达到终止状态并获得奖励 .从状态 出发,有很多可选的动作(例如有 1000 个可选的动作),但是这些动作都指向终止状态,并且奖励都服从均值为 0 、方差为 100 的正态分布.从理论上说,这个例子的最优价值函数为: , ,最优策略应当是 .但是,如果采用 Q 学习,在中间过程中会走一些弯路:在学习过程中,从 出发的某些动作会采样到比较大的奖励值,从而导致 会比较大,使得从 .这样的错误需要大量的数据才能纠正.为了解决这一问题,双重 Q 学习 ( Double Q Learning)算法使用两个独立的动作价值估计值 和 ,用 或 来 代替 Q 学习中的 .由于 和 是相互独立的估计,所以 ,其中 ,这样就消除了偏差.在双重学习的过程中, 和 都需要逐渐更新.所以,每步学习可以等概率选择以下两个更新 中的任意一个:
- 使用 来更新 以减小 和 之间的差别 (例如设定损失为 ,或采用 更新
- 使用 来更新 ,以减小 和 之间的差别 (例如设定损失为 ,或采用 更新
算法 3-9 给出了双重 Q 学习求解最优策略的算法.这个算法中最终输出的动作价值函数是 和 的平均值,即 .在算法的中间步骤,我们用这两个估计的和 来代替平均值 ,在略微简化的计算下也可以达到相同的效果.
算法 3-9: 双重 Q 学习算法求解最优策略
- (初始化) 任意值, .如果有终止状态,则令
- (时序差分更新)对每个回合执行以下操作.
2.1 (初始化状态)选择状态
2.2 如果回合未结束(比如未达到最大步数、 不是终止状态),执行以下操作:
- 用动作价值 确定的策略决定动作 (如 贪心策略)
- (采样)执行动作 ,观测得到奖励 和新状态
- (随机选择更新 或 )以等概率选择 或 中的一个动作价值函数作为更新对象,记选择的是
- (用改进后的策略更新回报的估计)
- (更新动作价值)更新 以减小 如
3.3 资格迹
资格迹是一种让时序差分学习更加有效的机制.它能在回合更新和单步时序差分更新之间折中,并且实现简单,运行有效.
3.3.1 回报
在正式介绍资格迹之前,我们先来学习 回报和基于 回报的离线 回报算法.给定 , 回报 return 是时序差分目标 按 加权平均的结果.对于连续性任务,有 对于回合制任务,则有 回报 可以看作是回合更新中的目标 和单步时序差分目标 的推广:当 时, 就是回合更新的回报;当 时, 就是单步时序差分目标.
离线 回报算法( offline -return algorithm )则是在更新价值(如动作价值 或状态价值 时,用 作为目标,试图减小 或 .它与回合更新算法相比,只是将更新的目标从 换为了 .对于回合制任务,在回合结束后为每一步 计算 ,并统一更新价值.因此,这样的算法称为离线算法( offline algorithm ).对于连续性任务,没有办法计算 ,所以无法使用离线 算法.
由于离线 回报算法使用的目标在 和 间做了折中,所以离线 回报算法的效果可能比回合更新和单步时序差分更新都要好.但是,离线 回报算法也有明显的缺点: 其一,它只能用于回合制任务,不能用于连续性任务;其二,在回合结束后要计算 ,计算量巨大.在下一节我们将采用资格迹来弥补这两个缺点.
3.3.2 TD()
TD () 是历史上具有重要影响力的强化学习算法,在离线 回报算法的基础上改进而来.以基于动作价值的算法为例,在离线 回报算法中,对任意的 ,在更新 或 时,时序差分目标 的权重是 .虽然需要等到回合结束才能计算 ,但是在知道 后就能计算 .所以我们在知道 后,就可 以用 去更新所有的 ,并且更新的权重与 成正比.
据此,给定轨迹 ,可以引入资格迹 来表示第 步的状态动作对 的单步自益结果 对每个状态动作对 需要更新的权重.资格迹(eligibility)用下列递推式定义:当 时, 当 时, 其中 是事先给定的参数.资格迹的表达式应该这么理解:对于历史上的某个状念动作对 ,距离第 步间隔了 步, 在 回报 中的权重为 ,并且 ,所以 是以 的比率折算到 中.间隔的步数每增加一步,原先的资格迹大致需要衰减为 倍.对当前最新出现的状态动作对 ,大的更新权重则要进行某种强化.强化的强度 常有以下取值:
- , 这时的资格迹称为累积迹(accumulating trace)
- (其中 是学习率),这时的资格迹称为荷兰迹(dutch trace)
- ,这时的资格迹称为替换迹(replacing trace).
当 时,直接将其资格迹加 1 ;当 时,资格迹总是取值在 范围内,所以让其资格迹直接等于 1 也实现了增加,只是增加的幅度没有 时那么大;当 时,增加的幅度在 和 之间.

利用资格迹,可以得到 策略评估算法.算法 3-10 给出了用 评估动作价值的算法.它是在单步时序差分的基础上,加入资格迹来实现的.资格迹也可以和最优策略求解算法结合,例如和 算法结合得到 算法.算法 3-10 中如果没有策略输入,在选择动作时按照当前动作价值来选择最优动作,就是 算法.
算法 3-10 的动作价值评估或 学习
输入:环境(无数学描述),若评估动作价值则需输入策略 . 输出:动作价值估计 参数:资格迹参数 ,优化器(隐含学习率 ),折扣因子 ,控制回合数和回合内步数的参数
- (初始化)初始化价值估计 任意值, .如果有终止状态,令
- 对每个回合执行以下操作:
2.1 (初始化资格迹)
2.2 (初始化状态动作对)选择状态 ,再根据输入策略 确定动作
2.3 如果回合未结東(比如未达到最大步数、S不是终止状态),执行以下操作:
- (采样)执行动作 ,观测得到奖励 和新状态
- 根据输入策略 或是迭代的最优价值函数 确定动作
- (更新资格迹)
- (计算回报的估计值)
- (更新价值)
- 若 ,则退出 2.2 步;否则 .
资格迹也可以用于状态价值.给定轨迹 ,资格迹 来表示第 步的状态动作对 对的单步自益结果 对每个状态 需要更新的权重,其定义为:当 时, 当 时, 算法 3-11 给出了用资格迹评估策略状态价值的算法.
算法 5-14 更新评估策略的状态价值
输入:环境(无数学描述)、策略 输出:状态价值函数 参数:资格迹参数 ,优化器(隐含学习率 ),折扣因子 ,控制回合数和回合内步数的参数
- (初始化)初始化价值 任意值, .如果有终止状态,
- 对每个回合执行以下操作:
2.1 (初始化资格迹)
2.2 (初始化状态)选择状态 S.
2.3 如果回合未结東(比如未达到最大步数、 不是终止状态),执行以下操作:
- 根据输入策略 确定动作
- (采样)执行动作 ,观测得到奖励 和新状态
- (更新资格迹)
- (计算回报的估计值)
- (更新价值)
算法与离线 回报算法相比,具有三大优点:
- 算法既可以用于回合制任务,又可以用于连续性任务
- 算法在每一步都更新价值估计,能够及时反映变化
- 算法在每一步都有均匀的计算,而且计算量都较小
四. 函数近似方法
第 章中介绍的有模型数值迭代算法、回合更新算法和时序差分更新算法,在每次更新价值函数时都只更新某个状态(或状态动作对)下的价值估计.但是,在有些任务中,状态和动作的数目非常大,甚至可能是无穷大,这时,不可能对所有的状态 (或状态动作对) 逐一进行更新.函数近似方法用参数化的模型来近似整个状态价值函数(或动作价值函数),并在每次学习时更新整个函数.这样,那些没有被访问过的状态(或状态动作对)的价值估计也能得到更新.本章将介绍函数近似方法的一般理论,包括策略评估和最优策略求解的一般理论.再介绍两种最常见的近似函数:线性函数和人工神经网络.后者将深度学习和强化学习相结合,称为深度 Q 学习,是第一个深度强化学习算法,也是目前的热门算法.
4.1 函数近似原理
本节介绍用函数近似(function approximation)方法来估计给定策略 的状态价值函数 或动作价值函数 .要评估状态价值,我们可以用一个参数为 的函数 来近似状态价值;要评估动作价值,我们可以用一个参数为 的函数 来近似动作价值.在动作集 有限的情况下,还可以用一个矢量函数一个动作,而整个矢量函数除参数外只用状态作为输人.这里的函数 、 形式不限,可以是线性函数,也可以是神经网络.但是,它们的形式要事先给定,在学习过程中只更新参数 .一旦参数 完全确定,价值估计就完全给定.所以,本节将介绍如何更新参数 .更新参数的方法既可以用于策略价值评估,也可以用于最优策略求解.
4.1.1 随机梯度下降
本节来看同策回合更新价值估计.将同策回合更新价值估计与函数近似方法相结合,可以得到函数近似回合更新价值估计算法(算法 4-1 ).这个算法与第 2 章中回合更新算法的区别就是在价值更新时更新的对象是函数参数,而不是每个状态或状态动作对的价值估计.
算法 6-1: 随机梯度下降函数近似评估策略的价值
- (初始化)任意初始化参数
- 逐回合执行以下操作
2.1 (采样)用环境和策略 生成轨迹样本
2.2 (初始化回报)
2.3 (逐步更新)对 ,执行以下步骤
- (更新回报)
- (更新价值)若评估的是动作价值则更新 以减小 (如 若评估的是状态价值则更新 以减小 .
如果我们用算法 4-1 评估动作价值,则更新参数时应当试图减小每一步的回报估计 和动作价值估计 的差别.所以,可以定义每一步损失为 ,而整个回合的损失为 .如果我们沿着 对 的梯度的反方向更新策略参数 ,就有机会减小损失.这样的方法称为随机梯度下降( stochastic gradient-descent, SGD )算法.对于能支持自动梯度计算的软件包,往往自带根据损失函数更新参数的功能.如果不使用现成的参数更新软件包,也可以自己计算得到 的 梯度 ,然后利用下式进行更新 : 对于状态价值函数,也有类似的分析.定义每一步的损失为 ,整个回合的损失为 .可以在自动梯度计算并更新参数的软件包中定义这个损失来更新参数 , 也可以用下式更新: 相应的回合更新策略评估算法与算法 4-1 类似,此处从略.
将策略改进引入随机梯度下降评估策略,就能实现随机梯度下降最优策略求解.算法 4-2 给出了随机梯度下降最优策略求解的算法.它与第 2 章回合更新最优策略求解算法的区别也仅仅在于迭代的过程中不是直接修改价值估计,而是更新价值参数 .
算法 4-2: 随机梯度下降求最优策略
- (初始化)任意初始化参数
- 逐回合执行以下操作
2.1 (采样)用环境和当前动作价值估计 导出的策略(如 柔性策略)生成轨迹样
本
2.2 (初始化回报)
2.3 (逐步更新)对 ,执行以下步骤:
- (更新回报)
- (更新动作价值函数)更新参数 以减小 如 .
4.1.2 半梯度下降
动态规划和时序差分学习都用了“自益”来估计回报,回报的估计值与 有关,是存在偏差的.例如,对于单步更新时序差分估计的动作价值函数,回报的估计为 ,而动作价值的估计为 ,这两个估计都与权重 有关.在试图减小每一步的回报估计 和动作价值估计 的差别时,可以定义每一步损失为 ,而整个回合的损失为 .在更新参数 以减小损失时,应当注意不对回报的估计 求梯度,只对动作价值的估计 求关于 的梯度,这就是半梯度下降(semi-gradient descent)算法.半梯度下降算法同样既可以用于策略评估,也可以用于求解最优策略(见算法 4-3 和算法 4-4 ).
算法 4-3: 半梯度下降算法估计动作价值或 算法求最优策略
- (初始化)任意初始化参数
- 逐回合执行以下操作
2.1 (初始化状态动作对)选择状态 .如果是策略评估,则用输入策略 确定动作 ;如果是寻找最优策略,则用当前动作价值估计 导出的策略 如 柔性策略 确定动作
2.2 如果回合未结東,执行以下操作:
- (采样)执行动作 ,观测得到奖励 和新状态
- 如果是策略评估,则用输入策略 确定动作 ;如果是寻找最优策略,则用当前动作价值估计 导出的策略 如 柔性策略)确定动作
- (计算回报的估计值)
- (更新动作价值函数)更新参数 以减小 如 .注意此步不可以重新计算
算法 4-4: 半梯度下降估计状态价值或期望 SARSA 算法或 学习
- (初始化)任意初始化参数
- 逐回合执行以下操作
2.1 (初始化状态)选择状态
2.2 如果回合未结束,执行以下操作
- 如果是策略评估,则用输入策略 确定动作 ;如果是寻找最优策略,则用当前动作价值估计 导出的策略(如 柔性策略)确定动作
- (采样)执行动作 ,观测得到奖励 和新状态
- (计算回报的估计值)如果是状态价值评估,则 .如果是期望 SARSA 算法,则 ,其中 是 确定的策略(如 柔性策略)若是 学习则
- (更新动作价值函数)若是状态价值评估则更新 以减小 (如 ,若是期望 算法或 学习则更新参数 以减小 如 .注意此步不可以重新计算
- .
如果采用能够自动计算微分并更新参数的软件包来减小损失,则务必注意不能对回报的估计求梯度.有些软件包可以阻止计算过程中梯度的传播,也可以在计算回报估计的表达式时使用阻止梯度传播的功能.还有一种方法是复制一份参数 ,在计算回报估计的表达式时用这份复制后的参数 来计算回报估计,而在自动微分时只对原来的参数进行微分,这样就可以避免对回报估计求梯度.
4.1.3 带资格迹的半梯度下降
在第 3 章中,我们学习了资格迹算法.资格迹可以在回合更新和单步时序差分更新之间进行折中,可能获得比回合更新或单步时序差分更新都更好的结果.回顾前文,在资格迹算法中,每个价值估计的数值都对应着一个资格迹参数,这个资格迹参数表示这个价值估计数值在更新中的权重.最近遇到的状态动作对(或状态)的权重大,比较久以前遇到的状态动作对(或状态)的权重小,从来没有遇到过的状态动作对(或状态)的权重为 0 .每次更新时,都可以更新整条轨迹上的资格迹,再利用资格迹作为权重,更新整条轨迹上的价值估计.
资格迹同样可以运用在函数近似算法中,实现回合更新和单步时序差分的折中.这时,资格迹对应价值参数 .具体而言,资格迹参数 和价值参数 具有相同的形状大小,并且逐元素一一对应.资格迹参数中的每个元素表示了在更新价值参数对应元素时应当使用的权重乘以价值估计对该分量的梯度.也就是说,在更新价值参数 的某个分量 对应着资格迹参数 中的某个分量 时,那么在更新 时应当使用以下迭代式更新: 对价值参数整体而言,就有 当选取资格迹为累积迹时,资格迹的递推定义式如下:当 时 当 时 资格迹的递推式由 2 项组成.递推式的第一项是对前一次更新时使用的资格迹衰减而来,衰减系数是 ,这是一个 0 到 1 之间的数.可以通过改变 的值,决定衰减的速度. 当 接近 0 时,衰减快;当 接近 1 时,衰减慢.递推式的第二项是加强项,它由动作价值的梯度值决定.动作价值的梯度值事实上确定了价值参数对总体价值估计的影响.对总体价值估计影响大的那些价值参数分量是当前比较重要的分量,应当加强它的资格迹.不过,梯度的分量值不一定是正数或 0 ,也可能是负数.所以,更新后的资格迹分量也可能是负值.当资格迹的某些分量是负值时,对应价值参数分量的权重值就是负值.进一步而言,在价值参数更新时,面对相同的时序差分误差,会出现价值参数的某些分量增大而另一些 分量减小的情况. 算法 4-5 和算法 4-6 给出了使用资格迹的价值估计和最优策略求解算法.这两个算法都 使用了累积迹.
算法 6-5 算法估计动作价值或 算法
- (初始化)任意初始化参数
- 逐回合执行以下操作
2.1 (初始化状态动作对)选择状态 如果是策略评估,则用输入策略 确定动作 ;如果是寻找最优策略,则用当前动作价值估计 导出的策略(如 柔性策略 确定动作
2.2 如果回合未结東,执行以下操作
- (采样)执行动作 ,观测得到奖励 和新状态
- 如果是策略评估,则用输入策略 确定动作 ;如果是寻找最优策略,用当前动作价值估计 导出的策略 如 柔性策略 确定动作 ;
- (计算回报的估计值)
- (更新资格迹)
- (更新动作价值函数)
- .
算法 4-6 TD(\lambda) 估计状态价值或期望 SARSA(\lambda) 算法或 学习
- (初始化)任意初始化参数
- 逐回合执行以下操作
2.1 (初始化资格迹)
2.2 (初始化状态)选择状态 S
2.3 如果回合未结束,执行以下操作
- 如果是策略评估,则用输入策略 确定动作 ;如果是寻找最优策略,用当前动作价值估计 导出的策略 如 柔性策略 确定动作
- (采样)执行动作 ,观测得到奖励 和新状态
- (计算回报的估计值)如果是状态价值评估,则 .如果是期望 SARSA 算法,则 ,其中 是 确定的策略 (如 柔性策略 .若是 学习则
- (更新资格迹)若是状态价值评估,则 ;若是期望 法或 学习,则
- (更新动作价值函数)若是状态价值评估,则 ;若是期望 算法或 学习,则
- .
TODO:线性近似
TODO:4.3函数近似的收敛性
4.4 深度Q学习
本节介绍一种目前非常热门的函数近似方法——深度 Q 学习.深度 Q 学习将深度学习和强化学习相结合,是第一个深度强化学习算法.深度 Q 学习的核心就是用一个人工神经网络 来代替动作价值函数.由于神经网络具有强大的表达能力,能够自动寻找特征,所以采用神经网络有潜力比传统人工特征强大得多.最近基于深度 Q 网络的深度强化学习算法有了重大的进展,在目前学术界有非常大的影响力.当同时出现异策、自益和函数近似时,无法保证收敛性,会出现训练不稳定或训练困难等问题.针对出现的各种问题,研究人员主要从以下两方面进行了改进.
- 经验回放(experience replay): 将经验(即历史的状态、动作、奖励等)存储起来,再在存储的经验中按一定的规则采样.
- 目标网络(target network): 修改网络的更新方式,例如不把刚学习到的网络权重马上用于后续的自益过程.
本节后续内容将从这两条主线出发,介绍基于深度 Q 网络的强化学习算法.
4.4.1 经验回放
V. Mnih 等在 2013 年发表文章《Playing Atari with deep reinforcement learning 》,提出了基于经验回放的深度 Q 网络,标志着深度 Q 网络的诞生,也标志着深度强化学习的诞生.在 4.2 节中我们知道,采用批处理的模式能够提供稳定性.经验回放就是一种让经验的概率分布变得稳定的技术,它能提高训练的稳定性.经验回放主要有“存储”和“采样回放”两大关键步骤.
- 存储:将轨迹以 等形式存储起来;
- 采样回放:使用某种规则从存储的 中随机取出一条或多条经验
算法 4-8 给出了带经验回放的 Q 学习最优策略求解算法
算法 6-8 带经验回放的 学习最优策略求解
- (初始化)任意初始化参数
- 逐回合执行以下操作
2.1 (初始化状态)选择状态
2.2 如果回合未结東,执行以下操作
- (采样)根据 选择动作 并执行,观测得到奖励 和新状态
- (存储)将经验 存入经验库中
- (回放)从经验库中选取经验
- (计算回报的估计值)
- (更新动作价值函数)更新 以减小 (如
经验回放有以下好处.
- 在训练 网络时,可以消除数据的关联,使得数据更像是独立同分布的(独立同分布是很多有监督学习的证明条件)这样可以减小参数更新的方差,加快收敛.
- 能够重复使用经验,对于数据获取困难的情况尤其有用.
从存储的角度,经验回放可以分为集中式回放和分布式回放.
- 集中式回放:智能体在一个环境中运行,把经验统一存储在经验池中.
- 分布式回放:智能体的多份拷贝(worker) 同时在多个环境中运行,并将经验统一存储于经验池中.由于多个智能体拷贝同时生成经验,所以能够在使用更多资源的同 时更快地收集经验
从采样的角度,经验回放可以分为均匀回放和优先回放.
- 均匀回放:等概率从经验集中取经验,并且用取得的经验来更新最优价值函数
- 优先回放(Prioritized Experience Replay, PER):为经验池里的每个经验指定一个优先级,在选取经验时更倾向于选择优先级高的经验.
T. Schaul 等于 2016 年发表文章 《Prioritized experience replay》,提出了优先回放.优先回放的基本思想是为经验池里的经验指定一个优先级,在选取经验时更倾向于选择优先级高的经验.一般的做法是,如果某个经验(例如经验 )的优先级为 ,那么选取该经验的概率为 经验值有许多不同的选取方法,最常见的选取方法有成比例优先和基于排序优先.
- 成比例优先(proportional priority):第 个经验的优先级为 .其中 是时序差分误差 定义为 或 是预先选择 的一个小正数, 是正参数.
- 基于排序优先(rank-based priority): 第 个经验的优先级为 .其中 是第 个经验从大到小排序的排名,排名从 1 开始.
D. Horgan 等在 2018 发表文章 《 Distributed prioritized experience replay》,将分布式经 签回放和优先经验回放相结合,得到分布式优先经验回放(distributed prioritized experience replay).
4.4.2 带目标网络的深度Q学习
对于基于自益的 Q 学习,其回报的估计和动作价值的估计都和权重 有关.当权重值变化时,回报的估计和动作价值的估计都会变化.在学习的过程中,动作价值试图追逐一个变化的回报,也容易出现不稳定的情况.在 4.1.2 节中给出了半梯度下降的算法来解决 这个问题.在半梯度下降中,在更新价值参数 时,不对基于自益得到的回报估计 求梯 度.其中一种阻止对 求梯度的方法就是将价值参数复制一份得到 ,在计算 时用 计算.基于这一方法,V. Mnih 等在 2015 年发表了论文《Human-level control through deep reinforcement learning》,提出了目标网(target network)这一概念.目标网络是在原有的神经网络之外再搭建一份结构完全相同的网络.原先就有的神经网络称为评估网络(evaluation network).在学习的过程中,使用目标网络来进行自益得到回报的评估值,作为学习的目标.在权重更新的过程中,只更新评估网络的权重,而不更新目标网络的权重.这样,更新权重时针对的目标不会在每次迭代都变化,是一个固定的目标.在完成一定次数的更新后,再将评估网络的权重值赋给目标网络,进而进行下一批更新.这样,目标网络也能得到更新.由于在目标网络没有变化的一段时间内回报的估计是相对固定的,目标网络的引入增加了学习的稳定性.所以,目标网络目前已经成为深度 Q 学习的主流做法.
算法 4-9 给出了带目标网络的深度 Q 学习算法.算法开始时将评估网络和目标网络初始化为相同的值.为了得到好的训练效果,应当按照神经网络的相关规则仔细初始化神经网络的参数.
算法 4-9 带经验回放和目标网络的深度 学习最优策略求解
- (初始化)初始化评估网络 的参数 ;目标网络 的参数
- 逐回合执行以下操作
2.1 (初始化状态)选择状态
2.2 如果回合未结束,执行以下操作:
- (采样)根据 选择动作 并执行,观测得到奖励 和新状态
- (经验存储)将经验 存入经验库 中
- (经验回放)从经验库 中选取一批经验
- (计算回报的估计值)
- (更新动作价值函数)更新 以减小 (如
- (更新目标网络)在一定条件下(例如访问本步若千次)更新目标网络的权重
在更新目标网络时,可以简单地把评估网络的参数直接赋值给目标网络 即 ,也可以引人一个学习率 把旧的目标网络参数和新的评估网络参数直接做加权平均后的值赋值给目标网络 即 .事实上,直接赋值的版本是带学习率版本在 时的特例.对于分布式学习的情形,有很多独立的拷贝(worker)同时会修改目标网络,则就更常用学习率 .
4.4.3 双重深度Q学习
第 3 章曾提到 学习会带来最大化偏差,而双重 学习却可以消除最大化偏差.基于查找表的双重 Q 学习引入了两个动作价值的估计 和 ,每次更新动作价值时用其中的一个网络确定动作,用确定的动作和另外一个网络来估计回报.
对于深度 Q 学习也有同样的结论.Deepmind 于 2015 年发表论文 《 Deep reinforcement learning with double Q-learning 》,将双重 Q 学习用于深度 Q 网络,得到了双重深度 Q 网络(Double Deep Q Network, Double ).考虑到深度 Q 网络已经有了评估网络和目标网络两个网络,所以双重深度 Q 学习在估计回报时只需要用评估网络确定动作,用目标网络确定回报的估计即可.所以,只需要将算法 4-10 中的 更换为 就得到了带经验回放的双重深度 Q 网络算法.
4.4.4 对偶深度 Q 网络
Z. Wang 等在 2015 年发表论文《Dueling network architectures for deep reinforcement learning 》,提出了一种神经网络的结构——对偶网络( duel network).对偶网络理论利用动作价值函数和状态价值函数之差定义了一个新的函数——优势函数(advantage function): 对偶 网络仍然用 来估计动作价值,只不过这时候 是状态价值估计 和优 势函数估计 的叠加,即 其中 和 可能都只用到了 中的部分参数.在训练的过程中, 和 是共 同训练的,训练过程和单独训练普通深度 Q 网络并无不同之处.不过,同一个 事实上存在着无穷多种分解为 和 的方式.如果某个 可以分解为某个 和 ,那么它也能分解为 和 ,其中 是任意一个只和状态 有关的函数.为了不给训练带来不必要的麻烦,往往可以通过增加一个由优势函数导出的量,使得等效的优势函数满足固定的特征,使得分解唯一.常见的方法有以下两种:
- 考虑优势函数的最大值,令 使得等效优势函数 满足
- 考虑优势函数的平均值,令 使得等效优势函数 满足
五. 回合更新策略梯度方法
本书前几章的算法都利用了价值函数,在求解最优策略的过程中试图估计最优价值函数,所以那些算法都称为最优价值算法(optimal value algorithm).但是,要求解最优策略不一定要估计最优价值函数.本章将介绍不直接估计最优价值函数的强化学习算法,它们试图用含参函数近似最优策略,并通过迭代更新参数值.由于迭代过程与策略的梯度有关,所以这样的迭代算法又称为策略梯度算法(policy gradient algorithm).
5.1 策略梯度算法的原理
基于策略的策略梯度算法有两大核心思想:
- 用含参函数近似最优策略
- 用策略梯度优化策略参数
本节介绍这两部分内容.
5.1.1 函数近似与动作偏好
用函数近似方法估计最优策略 的基本思想是用含参函数 来近似最优策略.由于任意策略 都需要满足对于任意的状态 ,均有 ,我们也希望 满足对于任意的状态 ,均有 .为此引入动作偏好函数(action preference function) ,其 softmax 的值为 ,即 在第 3~4 章中,从动作价值函数导出最优策略估计往往有特定的形式 (如 贪心策 略).与之相比,从动作偏好导出的最优策略的估计不拘泥于特定的形式,其每个动作都可以有不同的概率值,形式更加灵活.如果采用迭代方法更新参数 ,随着迭代的进行, 可以自然而然地逼近确定性策略,而不需要手动调节 等参数.
动作偏好函数可以具有线性组合、人工神经网络等多种形式.在确定动作偏好的形式中,只需要再确定参数 的值,就可以确定整个最优状态估计.参数 的值常通过基于梯度的迭代算法更新,所以,动作偏好函数往往需要对参数 可导.
5.1.2 策略梯度定理
策略梯度定理给出了期望回报和策略梯度之间的关系,是策略梯度方法的基础.本节学习策略梯度定理.
在回合制任务中,策略 期望回报可以表示为 .策略梯度定理(policy gradient theorem) 给出了它对策略参数 的梯度为 其等式右边是和的期望,求和的 中,只有 显式含有参数 .
策略梯度定理告诉我们,只要知道了 的值,再配合其他一些容易获得的 值(如 和 ,就可以得到期望回报的梯度.这样,我们也可以顺着梯度方向改变 以增大期望回报.
接下来我们来证明这个定理.回顾,策略 满足 期望方程,即 将以上两式对 求梯度,有 将 的表达式代人 的表达式中,有 在策略 下,对 求上式的期望,有 这样就得到了从 到 的递推式.注意到最终关注的梯度值就是 所以有 考虑到 所以 又由于 ,所以 得证.
5.2 同策回合更新策略梯度算法
策略梯度定理告诉我们,沿着 的方向改变策略参数 的值,就有机会增加期望回报.基于这一结论,可以设计策略梯度算法.本节考虑同策更新算法
5.2.1 简单的策略梯度算法
在每一个回合结束后,我们可以就回合中的每一步用形如 的迭代式更新参数 .这样的算法称为简单的策略梯度算法(Vanilla Policy Gradient, VPG).
R Willims 在文章《Simple statistical gradient-following algorithms for connectionist reinforcement learning 》中给出了该算法,并称它为“REward Increment = Nonnegative Factor Offset Reinforcement Characteristic Eligibility” ( ,表示增量 是由三个部分的积组成的.这样迭代完这个回合轨迹就实现了 在具体的更新过程中,不一定要严格采用这样的形式.当采用 TensorFlow 等自动微分的软件包来学习参数时,可以定义单步的损失为 ,让软件包中的优化器减小整个回合中所有步的平均损失,就会沿着 的梯度方向改变 的值.
简单的策略梯度算法见算法 5-1.
算法 5-1: 简单的策略梯度算法求解最优策略
输入:环境(无数学描述) 输出:最优策略的估计 参数:优化器(隐含学习率 ),折扣因子 ,控制回合数和回合内步数的参数
- (初始化) 任意值
- (回合更新)对每个回合执行以下操作
2.1 (采样)用策略 生成轨迹
2.2 (初始化回报)
2.3 对 ,执行以下步骤:
- (更新回报)
- (更新策略)更新 以减小
5.2.2 带基线的简单策略梯度算法
本节介绍简单的策略梯度算法的一种改进一带基线的简单的策略梯度算法(REINFOCE with baselines).为了降低学习过程中的方差,可以引人基线函数 .基线函数 可以是任意随机函数或确定函数,它可以与状态 有关,但是不能和动作 有关.满足这样的条件后,基线函数 自然会满足 证明如下:由于 与 无关,所以 进而 得证. 基线函数可以任意选择,例如以下情况
- 选择基线函数为由轨迹确定的随机变量 ,这时 ,梯度的形式为
- 选择基线函数为 ,这时梯度的形式 为
但是,在实际选择基线时,应当参照以下两个思想.
- 基线的选择应当有效降低方差.一个基线函数能不能降低方差不容易在理论上判别, 往往需要通过实践获知.- 基线函数应当是可以得到的.例如我们不知道最优价值函数,但是可以得到最优价值函数的估计.价值函数的估计也可以随着迭代过程更新.
一个能有效降低方差的基线是状态价值函数的估计.算法 5-2 给出了用状态价值函数的估计作为基线的算法.这个算法有两套参数 和 ,分别是最优策略估计和最优状态价值函数估计的参数.每次迭代时,它们都以各自的学习算法进行学习.算法 5-2 采用了随机梯度下降法来更新这两套参数(事实上也可以用其他算法),在更新过程中都用到了 ,可以在更新前预先计算以减小计算量.
算法 5-2: 带基线的简单策略梯度算法求解最优策略
输入:环境(无数学描述) 输出:最优策略的估计 参数:优化器(隐含学习率 ,折扣因子 ,控制回合数和回合内步数的参数.
- (初始化) 任意值, 任意值.
- (回合更新)对每个回合执行以下操作:
2.1 (采样)用策略 生成轨迹
2.2 (初始化回报)
2.3 对 ,执行以下步骤:
- (更新回报)
- (更新价值)更新 以减小 如
- (更新策略)更新 以减小 如
接下来,我们来分析什么样的基线函数能最大程度地减小方差.考虑 的方差为 其对 求偏导数为 (求偏导数时用到了 ).令这个偏导数为 0 ,并假设 可知 这意味着,最佳的基线函数应当接近回报 以梯度 为权重加权平均的结果.但是,在实际应用中,无法事先知道这个值,所以无法使用这样的基线函数.
值得一提的是,当策略参数和价值参数同时需要学习的时候,算法的收敛性需要通过双时间轴 Robbins-Monro 算法(two timescale Robbins-Monro algorithm)来分析.
5.3 异策回合更新策略梯度算法
在简单的策略梯度算法的基础上引入重要性采样,可以得到对应的异策算法.记行为策略为 ,有 即 所以,采用重要性采样的离线算法,只需要把用在线策略采样得到的梯度方向 改为用行为策略 采样得到的梯度方向 即可.这就意味着,在更新参数 时可以试图增大 .
算法5-3: 重要性采样简单策略梯度求解最优策略
- (初始化) 任意值
- (回合更新)对每个回合执行以下操作:
2.1 (行为策略)指定行为策略 ,使得
2.2 (采样)用策略 生成轨迹:
2.3 (初始化回报和权重)
2.4 对 ,执行以下步骤:
- (更新回报)
- (更新策略)更新参数 以减小 如
重要性采样使得我们可以利用其他策略的样本来更新策略参数,但是可能会带来较大的偏差,算法稳定性比同策算法差.
5.4 策略梯度更新和极大似然估计的关系
至此,本章已经介绍了各种各样的策略梯度算法.这些算法在学习的过程中,都是通过更新策略参数 以试图增大形如 的目标(考虑单个条目则为 ,其中 可取 等值.将这一学习过程与下列有监督学习最大似然问题的过程进行比较,如果已经有一个表达式未知的策略 ,我们要用策略 来近似它,这时可以考虑用最大似然的方法来估计策略参数 .具体而言,如果已经用未知策略 生成了很多样本,那么这些样本对于策略 的对数似然值正比于 .用这些样本进行有监督学习,需要更新策略参数 以增大 (考虑单个条目则为 .可以看出, 可以通过 中取 得到,在形式上具有相似性.策略梯度算法在学习的过程中巧妙地利用观测到的奖励信号决定每步对数似然值 对策略奖励的贡献,为其加权 (这里的 可能是正数,可能是负数,也可 能是 0 ),使得策略 能够变得越来越好.注意,如果取 ,在整个回合中是不变的(例如 ,那么在单一回合中的 就是对整个回合的对数似然值进行加权后对策略的贡献,使得策略 能够变得越来越好.试想,如果有的回合表现很好 (比如 是很大的正数 ),在策略梯度更新的时候这个回合的似然值 就会有一个比较大的权重 例如 ,这样这个表现比较好的回合就会更倾向于出现;如果有的回合表现很差(比如 是很小的负数,即绝对值很大的负数)则策略梯度更新时这个回合的似然值就会有比较小的权重,这样这个表现较差的回合就更倾向于不出现.
六. 执行者/评论者方法
本章介绍带自益的策略梯度算法.这类算法将策略梯度和自益结合了起来:一方面,用一个含参函数近似价值函数,然后利用这个价值函数的近似值来估计回报值;另一方面,利用估计得到的回报值估计策略梯度,进而更新策略参数.这两方面又常常被称为评论者 (critic) 和执行者(actor).所以,带自益的策略梯度算法被称为执行者 / 评论者算法(actorcritic algorithm).
6.1 执行者 / 评论者算法
同样用含参函数 表示偏好,用其 运算的结果 来近似最优策略.在更新参数 时,执行者 / 评论者算法依然也是根据策略梯度定理,取 为梯度方向迭代更新.其中, Schulman 等在文章《 High-dimensional continuous control using generalized advantage estimation 》中指出, 并不拘泥于以上形式. 可以是以下几种形式:
- (动作价值)
- (优势函数)
- (时序差分)
在以上形式中,往往用价值函数来估计回报.例如,由于 ,而且 也表征期望方向,所以 ,相当于用 表示期望.再例如,对于 ,就相当于在回报 的基础上减去基线 以减小方差.对于时序差分 ,也是用 代表回报,再减去基线 以减小方差.
不过在实际使用时,真实的价值函数是不知道的.但是,我们可以去估计这些价值函数.具体而言,我们可以用函数近似的方法,用含参函数 或 来近似 和 .在上一章中,带基线的简单策略梯度算法已经使用了含参函数 作为基线函数.我们可以在此基础上进一步引人自益的思想,用价值的估计 来代替 中表示回报的部分.例如,对于时序差分,用估计来代替价值函数可以得到 .这里的估计值 就是评论者,这样的算法就是执行者 / 评论者算法. >注意:只有采用了自益的方法,即用价值估计来估计回报,并引入了偏差,才是执行者 / 评论者算法.用价值估计来做基线并没有带来偏差(因为基线本来就可以任意选择).所以,带基线的简单策略梯度算法不是执行者 / 评论者算法.
6.1.1 动作价值执行者 / 评论者算法
根据前述分析,同策执行者 / 评论者算法在更新策略参数 时也应该试图减小 ,只是在计算 时采用了基于自益的回报估计.算法 6-1 给出了在回报估计为 ,并取 时的同策算法,称为动作价值执行者 / 评论者算法.算法一开始初始化了策略参数和价值参数.虽然算法中写的是可以将这个参数初始化为任意值,但是如果它们是神经网络的参数,还是应该按照神经网络的要求来初始化参数.在迭代过程中有个变量 ,用来存储策略梯度的表达式中的折扣因子 .在同一回合中,每一步都把这个折扣因子乘上 ,所以第 步就是 .
算法 6-1: 动作价值同策执行者 / 评论者算法
输入:环境(无数学描述) 输出:最优策略的估计 参数:优化器(隐含学习率 ,折扣因子 ,控制回合数和回合内步数的参数.
- (初始化) 任意值, 任意值
- (带自益的策略更新)对每个回合执行以下操作:
2.1 (初始化累积折扣)
2.2 (决定初始状态动作对)选择状态 ,并用 得到动作
2.3 如果回合未结束,执行以下操作:
- (采样)根据状态 和动作 得到奖励 和下一状态
- (执行)用 得到动作
- (估计回报)
- (策略改进)更新 以减小 如
- (更新价值)更新 以减小 如
- (更新累积折扣)
- (更新状态) .
6.1.2 优势执行者 / 评论者算法
在基本执行者 / 评论者算法中引入基线函数 ,就会得到 ,其中, 是优势函数的估计.这样,我们就得到了优势执行者 评论者算法.不过,如果采用 这样形式的优势函数估计值,我们就需搭建两个函数分别表示 和 .为了避免这样的麻烦,这里用了 做目标,这样优势函数的估计就变为单步时序差分的形式 .
如果优势执行者 / 评论者算法在执行过程中不是每一步都更新参数,而是在回合结束后用整个轨迹来进行更新,就可以把算法分为经验搜集和经验使用两个部分.这样的分隔可以让这个算法同时有很多执行者在同时执行.例如,让多个执行者同时分别收集很多经验,然后都用自己的那些经验得到一批经验所带来的梯度更新值.每个执行者在一定的时机更新参数,同时更新策略参数 和价值参数 .每个执行者的更新是异步的.所以,这样的并行算法称为异步优势执行者 / 评论者算法( Asynchronous Advantage Actor-Critic, ).异步优势执行者 / 评论者算法中的自益部分,不仅可以采用单步时序差分,也可以使用多步时序差分.另外,还可以对函数参数的访问进行控制,使得所有执行者统一更新参数.这样的并行算法称为优势执行者 / 评论者算法(Advantage Actor-Critic, ).算法 给出了异步优势执行者 / 评论者算法.异步优势执行者 / 评论者算法可以有许多执行者 (或称多个线程 ),所以除了有全局的价值参数 和策略参数 外,每个线程还可能有自己维护的价值参数 和 .执行者执行时,先从全局同步参数,然后再自己学习,最后统一同步全局参数.
算法 6-2: 优势执行者 / 评论者算法
输入:环境(无数学描述) 输出:最优策略的估计 参数:优化器(隐含学习率 ,折扣因子 ,控制回合数和回合内步数的参数.
- (初始化) 任意值, 任意值.
- (带自益的策略更新)对每个回合执行以下操作:
2.1 (初始化累积折扣)
2.2 (决定初始状态)选择状态
2.3 如果回合未结束,执行以下操作:
- (采样)用 得到动作
- (执行)执行动作 ,得到奖励 和观测
- (估计回报)
- (策略改进)更新 以减小 如
- (更新价值)更新 以减小 如
- (更新累积折扣)
- (更新状态).
算法 6-3: 异步优势执行者 / 评论者算法 (演示某个线程的行为)
输入:环境(无数学描述) 输出:最优策略的估计 参数:优化器(隐含学习率 ,折扣因子 ,控制回合数和回合内步数的参数
- (同步全局参数)
- 逐回合执行以下过程:
2.1 用策略 生成轨迹 ,直到回合结束或执行步数达
到上限
2.2 为梯度计算初始化:
- (初始化目标 )若 是终止状态,则 ;否则
- (初始化梯度) 2.3 (异步计算梯度)对 ,执行以下内容:
- (估计目标)计算
- (估计策略梯度方向)
- (估计价值梯度方向)
- (同步更新)更新全局参数 3.1 (策略更新)用梯度方向 更新策略参数 如 3.2 (价值更新)用梯度方向 更新价值参数 如 .
6.1.3 带资格迹的执行者 / 评论者方法
执行者 / 评论者算法引入了自益,那么它也就可以引入资格迹.算法 6-4 给出了带资格迹的优势执行者 / 评论者算法.这个算法里有两个资格迹 和 ,它们分别与策略参数 和价值参数 对应,并可以分别有自己的 和 .具体而言, 与价值参数 对应,运用梯度为 ,参数为 的累积迹; 与策略参数 对应,运用的梯度是 参数为 的累积迹,在运用中可以将折扣 整合到资格迹中.
算法 6-4: 带资格迹的优势执行者 / 评论者算法
输入:环境(无数学描述) 输出:最优策略的估计 参数:资格迹参数 ,学习率 ,折扣因子 ,控制回合数和回合内步数的参数.
- (初始化) 任意值, 任意值.
- (带自益的策略更新)对每个回合执行以下操作:
2.1 (初始化资格迹和累积折扣)
2.2 (决定初始状态)选择状态
2.3 如果回合未结束,执行以下操作:
- (采样)用 得到动作
- (执行)执行动作 ,得到奖励 和观测
- (估计回报)
- (更新策略资格迹)
- (策略改进)
- (更新价值资格迹)
- (更新价值)
- (更新累积折扣)
- (更新状态) .
6.2 基于代理优势的同策算法
本节介绍面向代理优势的执行者 / 评论者算法.这些算法在迭代的过程中并没有直接优化期望目标,而是试图优化期望目标近似一代理优势.在很多问题上,这些算法会比简单的执行者 / 评论者算法得到更好的性能.
6.2.1 代理优势
考虑采用迭代的方法更新策略 .在某次迭代后,得到了策略 .接下来我们希望得到一个更好的策略 .Kakade 等在文章 《 Approximately optimal approximate reinforcement learning 》中证明了策略 和策略 的期望回报满足性能差别引理 (Performance Difference Lemma): (证明: 得证.)
所以,要最大化 ,就是要最大化优势的期望 .这个期望是对含参策略而言的.要优化这样的期望,可以利用以下形式的重采样,将其中对 求期望转化为对 求期望: 但是,对 求期望无法进一步转化.代理优势( surrogate advantage)就是在上述重采样的基础上,将对 求期望近似为对 求期望: 这样得到了 的近似表达式 ,其中 可以证明, 和 在 处有相同的值 和梯度.
虽然 没有直接的表达式而很难直接优化,但是只要沿着它的梯度方向改进策略参数,就有机会增大它.由于 和 在 处有着相同的值和梯度方向, 和代理优势有着相同的梯度方向.所以,沿着 的梯度方向就有机会改进 .据此,我们可以得到以下结论:通过优化代理优势,有希望找到更好的策略.
6.2.2 邻近策略优化
我们已经知道代理优势与真实的目标相比,在 处有相同的值和梯度.但是,如果 和 差别较远,则近似就不再成立.所以针对代理优势的优化不能离原有的策略太远.基于这一思想,J. Schulman 等在文章 《 Proximal policy optimization algorithms 》中提出了邻近策略优化 (Proximal Policy Optimization) 算法,将优化目标设计为 其中 是指定的参数.采用这样的优化目标后,优化目标至多比 大 ,所以优化问题就没有动力让代理优势 变得非常大,可以避免迭代后的策略与迭代前的策略差距过大.
算法 6-5 给出了邻近策略优化算法的简化版本.
算法 6-5: 邻近策略优化算法 (简化版本 )
输入:环境(无数学描述) 输出:最优策略的估计 参数:策略更新时目标的限制参数 ,优化器,折扣因子 ,控制回合数和回合内步数的参数.
- (初始化) 任意值, 任意值
- (时序差分更新)对每个回合执行以下操作: 2.1 用策略 生成轨迹 2.2 用生成的轨迹由 确定的价值函数估计优势函数 如 2.3 (策略更新)更新 以增大 2.4 (价值更新)更新 以减小价值函数的误差(如最小化
在实际应用中,常常加人经验回放.具体的方法是,每次更新策略参数 和价值参数 前得到多个轨迹,为这些轨迹的每一步估计优势和价值目标,并存储在经验库 中.接着多次执行以下操作:从经验库 中抽取一批经验 ,并利用这批经验回放并学习,即从经验库中随机抽取一批经验并用这批经验更新策略参数和价值参数.
注意:邻近策略优化算法在学习过程中使用的经验都是当前策略产生的经验,所以使用了经验回放的邻近策略优化依然是同策学习算法.
6.3 信任域算法
信任域方法(Trust Region Method, TRM)是求解非线性优化的常用方法,它将一个复杂的优化问题近似为简单的信任域子问题再进行求解.
本节将介绍三种同策执行者 / 评论者算法:
- 自然策略梯度算法
- 信任域策略优化算法
- Kronecker 因子信任域执行者 / 评论者算法
这三个算法十分接近,它们都是以试图通过优化代理优势,迭代更新策略参数,进而找到最优策略的估计.在优化的过程中,也需要让新的策略和旧的策略不能相差太远.和上节介绍的邻近策略优化相比,它们在代理优势的基础上可以进一步引入信任域,要求新的策略在一个信任域内.本节将介绍信任域的定义(包括用来定义信任域的 散度的定义),再介绍如何利用信任域实现这些算法.
6.3.1 KL 散度
我们先来看 散度的定义.回顾重要性采样的章节,我们知道,如果两个分布 和 ,满足对于任意的 ,均有 ,则称分布 对 分布 绝对连续,记为 .在这种情况下,我们可以定义从分布 到分布 的 KL 散度 (Kullback-Leibler divergence): 当 和 是离散分布时, 当 和 是连续分布时, 散度有个性质:相同分布的 散度为 0 ,即 .
TODO:信任域A-C算法
重要性采样异策执行者 / 评论者算法
执行者 / 评论者算法可以和重要性采样结合,得到异策执行者 / 评论者算法.本节介绍基于重要性采样的异策执行者 / 评论者算法.
6.4.1 基本的异策算法
本节介绍基于重要性采样的异策的执行者/评论者算法(Off-Policy Actor-Critic, OffPAC ).
用 表示行为策略,则梯度方向可由 变为 .这时,更新策略参数 时就应该试图减小 .据此,可以得到异策执行者 / 评论者算法,见算法 6-10.
算法 8-10: 异策动作价值执行者 / 评论者算法
输入:环境(无数学描述) 输出:最优策略的估计 参数:优化器(隐含学习率 ,折扣因子 ,控制回合数和回合内步数的参数.
- (初始化) 任意值, 任意值
- (带自益的策略更新)对每个回合执行以下操作:
2.1 (初始化累积折扣)
2.2 (初始化状态动作对)选择状态 ,用行为策略 得到动作
2.3 如果回合未结束,执行以下操作:
- (采样)根据状态 和动作 得到采样 和下一状态
- (执行)用 得到动作
- (估计回报)
- (策略改进)更新 以减小 如
- (更新价值)更新 以减小 如
- (更新累积折扣)
- (更新状态)
6.4.2 带经验回放的异策算法
本节介绍 Wang 等在文章 《 Sample efficient actor-critic with experience replay》中提出的带经验回放的执行者 / 评论者算法 ( Actor-Critic with Experiment Replay, .如果说 节介绍的基本异策执行者 / 评论者算法是 节介绍的基本同策执行者 / 评论者算法的异策版本,那么本节介绍的带经验回放的异策执行者 / 评论者算法就相当于 节介绍的 算法的异策版本.它同样可以支持多个线程的异步学习:每个线程在执行前先同步全局参数,然后独立执行和学习,再利用学到的梯度方向进行全局更新.
6.1.2 节中介绍的执行者 / 评论者算法是基于整个轨迹进行更新的.对于引入行为策略和重采样后,对于目标 的重采样系数变为 , 其中 .在这个表达式中,每个 都有比较大的方差,最终乘积得到的方差会特别大.一种限制方差的方法是控制重采样比例的范围,例如给定一个常数 ,将重采样比例截断为 .但是,如果直接将梯度方向中的重采样系数进行截断(例如从 修改为 ),会带来偏差.这时候我们可以再加一项来弥补这个偏差.利用恒等式 ,我们可以把梯度 拆成以下两项: 期望针对行为策略 ,此项方差是可控的; 即 采用针对原有目标策略 的期望后, 也是有界的 (即 ).
采用这样的拆分后,两项的方差都是可控的.但是,这两项中其中一项针对的是行为策略,另外一项针对的是原策略,这就要求在执行过程中兼顾这两种策略.
得到梯度方向后,我们希望对这个梯度方向做修正,以免超出范围.为此,用 散度增加了约束。记在迭代过程中策略参数的指数滑动平均值为 ,对应的平均策略为 .我 们可以希望迭代得到的新策略参数不要与这个平均策略 参数差别太大.所以,可以限定这两个策略在当前状态 下的动作分布不要差别太大.考虑到 KL 散度可以刻画两个分布直接的差别,所以可以限定新得到的梯度方向(记为 ) 与 的内积不要太大.值得一提的是, 实际上有和重采样比例类似的形式: 至此,我们可以得到一个确定新的梯度方向的优化问题.记新的梯度方向为 ,定义 我们一方面希望新的梯度方向 要和 尽量接近,另外一方面要满足 ,不超过一个给定的参数 .这样这个优化问题为 接下来求解这个优化问题.使用 Lagrange 乘子法,构造函数: 将前式代人后式可得 .由于 Lagrange 乘子应大等于 0 ,所以,Lagrange 乘子应为 ,优化问题的最优解为 这个方向才是我们真正要用的梯度方向.综合以上分析,我们可以得到带经验回放的执行者 / 评论者算法的一个简化版本.这个算法可以有一个回放因子,可以控制每次运行得到的经验可以回放多少次.算法 6-11 给出了经验回放的线程的算法.对于经验回放的线程所回放的经验是从其他线程已经执行过的线程生成并存储的,这个过程在算法 6-11 中没有展示,但是是这个算法必需的.在存储和回放的时候,不仅要存储和回放状态 , 动作 、奖励 等,还需要存储和回放在状态 产生动作 的概率 .有了这个概率值,才能计算重采样系数.在价值网络的设计方面,只维护动作价值网络.在需要状态价值的估计时,由动作价值网络计算得到.
算法 6-11: 带经验回放的执行者 / 评 论者算法 (异策简化版本)
参数: 学习率 ,指数滑动平均系数 ,重采样因子截断系数 ,折扣因子 ,控制回合数和回合内步数的参数.
- (同步全局参数)
- (经验回放)回放存储的经验轨迹 ,以及经验对应的行为策略概率
- 梯度估计
3.1 为梯度计算初始化:
- (初始化目标 )若 是终止状态,则 ;否则
- (初始化梯度) 3.2 (异步计算梯度) 对 ,执行以下内容:
- (估计目标)计算
- (估计价值梯度方向)
- (估计策略梯度方向)计算动作价值 ,重采样系数 以及
- (更新回溯目标)
- (同步更新)更新全局参数. 4.1 (价值更新) 4.2 (策略更新) 4.3 (更新平均策略)
TODO:柔性A-C
七. 连续动作空间的确定性策略
简单易懂:b站搬运ShuSenWang教程
如何理解:两个网络:策略网络与价值函数网络( 函数 ) , 时刻,先利用策略时序差分地更新价值函数,再更新策略网络,策略网络的梯度下降想法是:参数朝着使 函数增大的方向走,即 函数关于策略网络的参数求梯度,所以最后推得的关系式策略网络的更新式形式为连式法则的样子 .
本章介绍在连续动作空间里的确定性执行者 / 评论者算法.在连续的动作空间中,动作的个数是无穷大的.如果采用常规方法,需要计算 .而对于无穷多的动作,最大值往往很难求得.为此,D. Silver 等人在文章《 Deterministic Policy Gradient Algorithms 》中提出了确定性策略的方法,来处理连续动作空间情况.本章将针对连续动作空间,推导出确定性策略的策略梯度定理,并据此给出确定性执行者 / 评论者算法.
7.1 同策确定性算法
对于连续动作空间里的确定性策略, 并不是一个通常意义上的函数,它对策略参数 的梯度. 也不复存在.所以,第 6 章介绍的执行者 / 评论者算法就不再适用.幸运的是,曾提到确定性策略可以表示为 .这种表示可以绕过由于 并不是通常意义上的函数而带来的困难.
本节介绍在连续空间中的确定性策略梯度定理,并据此给出基本的同策确定性执行者 / 评论者算法.
7.1.1 策略梯度定理的确定性版本
当策略是一个连续动作空间上的确定性的策略 时,策略梯度定理为 (证明:状态价值和动作价值满足以下关系 以上两式对 求梯度,有 将 的表达式代人 的表达式中,有 对上式求关于 的期望,并考虑到 (其中 任取),有 这样就得到了从 到 的递推式.注意,最终关注的梯度值就是 所以有 就得到和之前梯度策略定理类似的形式 ).
对于连续动作空间中的确定性策略,更常使用的是另外一种形式: 其中的期望是针对折扣的状态分布 (discounted state distribution) 而言的。(证明: 得证.)
7.1.2 基本的同策确定性执行者 / 评论者算法
根据策略梯度定理的确定性版本,对于连续动作空间中的确定性执行者 / 评论者算法,梯度的方向变为 确定性的同策执行者 评论者算法还是用 来近似 .这时, 近似为 所以,与随机版本的同策确定性执行者 / 评论者算法相比,确定性同策执行者 / 评论者算法在更新策略参数 时试图减小 .迭代式可以是 算法 7-1 给出了基本的同策确定性执行者 / 评论者算法.对于同策的算法,必须进行探索.连续性动作空间的确定性算法将每个状态都映射到一个确定的动作上,需要在动作空间添加扰动实现探索.具体而言,在状态 下确定性策略 指定的动作为 ,则在同策算法中使用的动作可以具有 的形式,其中 是扰动量.在动作空间无界的情况下(即没有限制动作有最大值和最小值),常常假设扰动量 满足正态分布.在动作空间有界的情况下,可以用 clip 函数进一步限制加扰动后的范围(如 ,其中 和 是动作的最小取值和最大取值),或用 sigmoid 函数将对加扰动后的动作变换到合适的区间里 如 .
算法 7-1: 基本的同策确定性执行者 / 评论者算法
输入: 环境(无数学描述) 输出:最优策略的估计 参数:学习率 ,折扣因子 ,控制回合数和回合内步数的参数.
- (初始化) 任意值,任意值
- (带自益的策略更新)对每个回合执行以下操作:
2.1 (初始化累积折扣)
2.2 (初始化状态动作对)选择状态 ,对 加扰动进而确定动作 (如用正态分布随机变量扰动)
2.3 如果回合未结束,执行以下操作:
- (采样)根据状态 和动作 得到采样 和下一状态
- (执行)对 加扰动进而确定动作
- (估计回报)
- (更新价值)更新 以减小 如
- (策略改进)更新 以减小 如
- (更新累积折扣)
- (更新状态) .
在有些任务中,动作的效果经过低通滤波器处理后反映在系统中,而独立同分布的 Gaussian 噪声不能有效实现探索.例如,在某个任务中,动作的直接效果是改变一个质点的加速度.如果在这个任务中用独立同分布的 Gaussian 噪声叠加在动作上,那么对质点位置的整体效果是在没有噪声的位置附近移动.这样的探索就没有办法为质点的位置提供持续的偏移,使得质点到比较远的位置.在这类任务中,常常用 Ornstein Uhlenbeck 过 程作为动作噪声.Ornstein Uhlenbeck 过程是用下列随机微分方程定义的 (以一维的情况为例 ) 其中 是参数 是标准 Brownian 运动.当初始扰动是在原点的单点分布(即限定 ),并且 时,上述方程的解为 (证明:将 代入 化简可得 .将此式从 0 积到 ,得 .当 且 时化简可得结果).
这个解的均值为 0 ,方差为 ,协方差为 (证明:由于均值为 0 ,所以 .另外,Ito Isometry 告诉我们 ,所以 ,进一步化简可得结果.)
对于 总有 ,所以 .据此可知,使用 Ornstein Uhlenbeck 过程让相邻扰动正相关,进而让动作向相近的方向偏移.
7.2 异策确定性算法
对于连续的动作空间,我们希望能够找到一个确定性策略,使得每条轨迹的回报最大.同策确定性算法利用策略 生成轨迹,并在这些轨迹上求得回报的平均值,通过让平均回报最大,使得每条轨迹上的回报尽可能大.事实上,如果每条轨迹的回报都要最大,那么对于任意策略 采样得到的轨迹,我们都希望在这套轨迹上的平均回报最大.所以异策确定性策略算法引入确定性行为策略 ,将这个平均改为针对策略 采样得到的轨迹,得到异策确定性梯度为 这个表达式与同策的情形相比,期望运算针对的表达式相同.所以,异策确定性算法的迭代式与同策确定性算法的迭代式相同.
异策确定性算法可能比同策确定性算法性能好的原因在于,行为策略可能会促进探索,用行为策略采样得到的轨迹能够更加全面的探索轨迹空间.这时候,最大化对轨迹分布的平均期望时能够同时考虑到更多不同的轨迹,使得在整个轨迹空间上的所有轨迹的回报会更大.
7.2.1 基本的异策确定性执行者 / 评论者算法
基于上述分析,我们可以得到异策确定性执行者 / 评论者算法 (Off-Policy Deterministic Actor-Critic, ),见算法 7-2 .
值得一提的是,虽然异策算法和同策算法有相同形式的迭代式,但是在算法结构上并不完全相同.在同策算法迭代更新时,目标策略的动作可以在运行过程中直接得到;但是在异策算法迭代更新策略参数时,对环境使用的是行为策略决定的动作,而不是目标策略决定的动作,所以需要额外计算目标策略的动作.在更新价值函数时,采用的是 学习,依然需要计算目标策略的动作.
算法 9-2: 基本的异策确定性执行者 / 评论者算法
输入:环境(无数学描述) 输出:最优策略的估计 参数:学习率 ,折扣因子 ,控制回合数和回合内步数的参数.
- (初始化) 任意值, 任意值.
- (带自益的策略更新)对每个回合执行以下操作:
2.1 (初始化累积折扣)
2.2 (初始化状态)选择状态
2.3 如果回合未结束,执行以下操作:
- (执行)用 得到动作
- (采样)根据状态 和动作 得到采样 和下一状态
- (估计回报)
- (更新价值)更新 以减小 如
- (策略改进)更新 以减小 (如
- (更新累积折扣)
- (更新状态) .
7.2.2 深度确定性策略梯度算法
深度确定性策略梯度算法(Deep Deterministic Policy Gradient, )将基本的异策 确定性执行者 / 评论者算法和深度 Q 网络中常用的技术结合.具体而言,确定性深度策略梯度算法用到了以下技术.
- 经验回放:执行者得到的经验 收集后放在一个存储空间中,等更新参数时批量回放,用批处理更新.
- 目标网络:在常规价值参数 和策略参数 外再使用一套用于估计目标的目标价值参数 和目标策略参数 在更新目标网络时,为了避免参数更新过快,还引入了目标网络的学习率
算法 7-3 给出了深度确定性策略梯度算法.
算法 7-3: 深度确定性策略梯度算法 (假设 总是在动作空间内)
输入:环境(无数学描述) 输出:最优策略的估计 参数:学习率 ,折扣因子 ,控制回合数和回合内步数的参数,目标网络学习率
- (初始化) 任意值, 任意值,
- 循环执行以下操作:
2.1 (累积经验)从起始状态 出发,执行以下操作,直到满足终止条件:
- 用对 加扰动进而确定动作 (如用正态分布随机变量扰动)
- 执行动作 ,观测到收益 和下一状态
- 将经验 存储在经验存储空间 2.2 (更新)在更新的时机,执行一次或多次以下更新操作:
- (回放)从存储空间 采样出一批经验
- (估计回报)为经验估计回报
- (价值更新)更新 以减小
- (策略更新)更新 以减小 (如
- (更新目标)在恰当的时机更新目标网络和目标策略, .
7.2.3 双重延迟深度确定性策略梯度算法
S. Fujimoto 等人在文章 《 Addressing function approximation error in actor-critic methods》 中给出了双重延迟深度确定性策略梯度算法(Twin Delay Deep Deterministic Policy Gradient, TD3 ),结合了深度确定性策略梯度算法和双重 Q 学习.
回顾前文,双重 学习可以消除最大偏差.基于查找表的双重 Q 学习用了两套动作价值函数 和 ,其中一套动作价值函数用来计算最优动作(如 ,另外一套价值函数用来估计回报(如 ;双重 网络则考虑到有了目标网络后已经有了两套价值函数的参数 和 所以用其中一套参数 计算最优动作(如 ),再用目标网络的参数 估计目标 (如 ). 但是对于确定性策略梯度算法,动作已经由含参策略 决定了 (如 ),双重网络则要由双重延迟深度确定性策略梯度算法维护两份学习过程的价值网络参数 和目标网络参数 .在估计目标时,选取两个目标网络得到的结果中较小的那个,即 .
算法 7-4 给出了双重延迟深度确定性策略梯度算法.
算法 7-4 :双重延迟深度确定性策略梯度
输入: 环境(无数学描述) 榆出:最优策略的估计 参数:学习率 ,折扣因子 ,控制回合数和回合内步数的参数,目标网络学习率 .
- (初始化) 任意值, 任意值,
- 循环执行以下操作:
2.1 (累积经验)从起始状态 出发,执行以下操作,直到满足终止条件:
- 用对 加扰动进而确定动作 (如用正态分布随机变量扰动)
- 执行动作 ,观测到收益 和下一状态
- 将经验 存储在经验存储空间 2.2 (更新)一次或多次执行以下操作:
- (回放)从存储空间 采样出一批经验
- (扰动动作)为目标动作 加受限的扰动,得到动作
- (估计回报)为经验估计回报
- (价值更新)更新 以减小
- (策略更新)在恰当的时机,更新 以减小 如
- (更新目标)在恰当的时机,更新目标网络和目标策略, .
TODO:AlphaZero算法
RL 在金融中的应用
参见Modern Perspectivs on RL in Finance和RL in economics and finance 2021.
本来应该通过动态规划方法解这些问题.用动态规划解优化问题通常需要下述三个条件:
- 明确知道模型的状态转移概率
- 有足够的算力来求解DP
- Markov性质
RL,结合了DP,蒙特卡洛模拟,函数近似和机器学习.
RL在金融中主要有以下三个应用方向:
- 衍生品定价/对冲
- 投资组合/资产配置
- 做市
一. RL for Risk Management
通常而言,学术中对衍生品定价和对冲都是基于随机环境下的有模型决策(model-driven decision rules in a stochastic environment),常规对冲策略都会用到希腊值 Greeks,代表模型对不同参数风险定价的敏感程度.
这种方法在高维情况时通常缺少有效的数值模拟方法.
Deep Hedging
参考Deep Hedging, Buehler et al.
市场摩擦(market frictions):指金融资产在交易中存在的难度,如手续费(transaction costs)、买卖价差(bid/ask spread)、流动性约束(liquidity constraints)等.
本文中的对冲的对象是对冲掉一些衍生品的投资组合.
把 trading decision 建模成一个网络,特征不仅仅有价格,还有交易信号,新闻分析(news analytics),过去对冲决策等等.
算法是完全 model-free,不依赖对应市场的动力.我们只需确定下来市场的状态生成(scenario generator),损失函数,市场摩擦和交易行为(trading instruments).所以此方法 lends itself to a statistically driven market dynamics,我们不需要像传统方法那样计算单个衍生品的希腊值,应将建模的精力花在实现真实的市场动力和样本外表现.
建模:带市场摩擦的离散市场
考虑有限时域的离散金融市场 和交易时刻 . 固定一个有限概率空间 和一个概率测度 s.t. 对所有的 . 定义所有 上的实值随机变量 .
记 在 available 的新市场数据, including market costs and mid-prices of liquid instruments-typically quoted in auxiliary terms such as implied volatilities-news, balance sheet information, any trading signals, risk limits etc. 过程 生成域流 , i.e. 表示到 时刻所有可用的信息. 注意到每个 可测的随机变量可以写成 的函数.
市场有 个hedging instruments with mid-prices 取值于 -valued -adapted 随机过程 . 即可以用来做对冲的资产.
衍生品的投资组合即负债(liability)是一个 可测的随机变量 . 到期日 是所有衍生品种最大的那一个.此为想要对冲掉的东西.
即 用.
想要在 对冲掉 , 我们要用 -valued -adapted stochastic process with 来交易 . 表示智能体在 时刻对第 个资产的持有. 同样定义 .
记 是这样的交易策略的无约束集合. 但是每个 有其交易约束. 可能来自 liquidity, asset availability or trading restrictions. They are also used to restrict trading in a particular option prior to its availability. In the example above of an option which is listed in , the respective trading constraints would be until the point. 所以我们假定 受 约束,由一个连续 可测映射给出 , i.e.
对无约束策略 , 我们定义它在 的有约束投影 .记 为受约束的交易策略相应的非空集.
EXAMPLE 1 Assume that are a range of options and that computes the Black-Scholes Vega of each option using the various market parameters available at time . The overall Vega traded with is then A liquidity limit of a maximum tradable Vega of could then be implemented by the map:
对冲
交易是自融资的.不带交易费用的 时刻最终财富为 , 其中 当考虑交易费用时,在 时刻买入 的 股票a产生费用 . 策略总的交易费用为 回忆 .所以智能体总的花费变为
DRL
期权
期权基础知识
一.基础概念
1.1 BS公式(Delta对冲下)
BS部分参考知乎专栏:Black-Scholes 模型学习框架
假设股价满足 自融资资产组合(此处为期权价格的一个复制)价值变化为 对 ( 时刻衍生品的价值)做Ito公式,比较项的系数得到Delta 对冲下 的 BS 偏微分方程 (在比较系数过程中会使用到替换为,即为持有的债券份额为自融资总份额减去持有的标的份额,标的的持有量已经使用为) 终值条件(看涨期权)为 解即为 BS 公式 注意到 的漂移项 对期权定价没有影响.
BS公式解法(欧式call)
有考虑边界条件
注意到这是一个 Cauchy-Euler 方程,能通过下述变量代换将其转化为一个扩散方程 The solution of the PDE gives the value of the option at any earlier time, Black-Scholes PDE 变为一个扩散方程: 终值条件 现在变为了初值条件 其中 是 Heaviside 阶梯函数,
使用解给定了初值函数 的扩散方程的标准卷积法,得到 经过处理,得到: 其中 N 为正态分布累计密度函数,
BS model给出了期权价格的函数,作为一个波动率的函数.可以通过这个公式在给定期权价格时计算隐含波动率(implied volatility).但事实是 BS 波动率强烈依赖于欧式期权的到期日和行权价.
波动率微笑是指给定到期日下,隐含波动率与行权价(maturity)的关系.
1.2 Delta对冲
我们现在考虑用衍生品 和其标的资产 构建一个“无风险组合”,考虑这样的自融资组合 ,即每一单位的空头衍生品,我们用 单位多头的股票对其进行对冲 (Hedging),由于其自融资的特性,根据定义,我们有 ,将股价的 SDE 和上一节中通过伊藤-德布林公式求出的 带入这个式子,我们可以得到: 因为要使资产组合为无风险的, >一个自融资组合 如果是无风险的,则可以表示为 ,且 . 1式即 Delta 对冲法则,将 带入2式我们再次得到 BlackScholes 偏微分方程: .
1.3 BS公式(风险中性定价下)
1.3.1 鞅
定义 在 上的随机过程 , ,称其是关于域流 的鞅,如果满足:
- 是 -适应的 (adapted);
- ;
- 对于 ,有 .
1.3.2 Radon-Nikodym导数
定义 设 和 为 上的等价测度,若 ,a.s.,且有 , 则称 是 关于 的 Radon-Nikodym 导数,记作: .
,即 ,其中 表示在测度 下的期望.进一步的,可以用条件期望定义R-N导数过程: .用鞅和RN导数过程的定义,可以简单的证明,R-N导数过程 是一个 -鞅.
1.3.3 资产的现值
用 表示风险资产的价值过程.首先要知道,在 模型的假设下,市场是完备 (Complete) 的,即任意资产 都可以被风险资产 和无风险资产 构成的组合所复制,即对任意一个 ,我们可以把它表示 为一个自融资组合: 可以看到该组合的收益率部分由组合的时间价值 与风险资产的超额收益 构成.我们考虑该资产的折现价值过程
1.3.4 鞅表示
定理(鞅表示) 设 是 上的布朗运动,而 为 -鞅, 且满足 ,则存在一个 适应的过程 ,使得 .
可以看到,如果 是鞅,那么 可以被表示为一个伊藤积分的形式,即没有 项而仅仅只有 项.再看我们的折现价值过程 ,如果想让它只有 项从而变成一个鞅,我们貌似只需要做变换: ,这样折现价值过程就可以被表示为:
但是,鞅表示定理有个非常非常重要的前提,就是你需要保证 这玩意儿是个伊藤积分,即 需要是一个布朗运动. 我们知道是布朗运动,但是经过这样变换过后的 还是布朗运动么,或者说我们需要如何选择新的测度, 来保证经过变换之后的 仍然是个布朗运动?
1.3.5 Girsanov定理
定理(Grisanov) 设 是 上的布朗运动. 为一个相适应的过程, 定义指数鞅过程,.其中 是初值 的相适应的过程, 表示二次变差.则可以定义新的概率测度 .如果在概率测度 下 是一个布朗运动,那么: 在新的概率测度 下也是一个布朗运动.
这样一来, 我们就找到了新的测度和两个测度之下布朗运动之间的关系.我们看新定义的这个布朗运动:,它的实质是把资产的风险溢价项给消除了.风险溢价是什么?是对承担单位风险的补偿,在新的测度下风险溢价是没有补偿的,所以说在这个世界里,风险是中性的,因此我们把这样定义的新测度 称为风险中性测度,并且用 来表示.
1.3.6 风险定价公式
现在我们知道了变换公式 ,那么在风险中性测度 下,风险资产 所满足的 SDE 也需要进行相应的变化: 由此可见,在风险中性世界里,风险资产 (例如股票) 的收益率完全等于无风险收益率.
此时任意资产的折现价值过程可以被表示为: .我们知道 在 下是一个鞅,那么由鞅的性贡我们可以知道: .常利率假设下有: .
假设我们需要对一个欧式看涨期权进行定价,我们知道该期权在到期日 的价值为 ,则有: 其中: 与PDE方法一致.
我们来总结一下 Risk-neutral Pricing 的几个步骤:
- 找到资产的折现价值过程;
- 作测度变换令这个折现价值在新的测度下为鞅;
- 用 Girsanov 定理找到新的变换;
- 利用鞅性质得到风险中性定价公式。
二.波动率价差
2.1 各种形式
2.1.1 跨式期权
跨式期权(straddle)由一个看涨期权和一个看跌期权组成,这两个期权具有相同的行权价格和到期日.在跨式期权中,这两个期权要么同时买入(跨式期权多头),要么同时卖出(跨式期权空头).
2.1.2 宽跨式期权
与跨式期权一样,宽跨式期权(straggle)由一个看涨期权和一个看跌期权组成,且两个期权的到期时间相同.但在宽跨式期权中,两个期权的行权价格不同.
为了避免混淆,通常假设宽跨式期权由虚值期权组成.如果当前标的市场价格为 100,而交易者想买入 3 月行权价格为 90/110 的宽跨式期权,这意味着他想买入 1 份 3 月行权价格为 90 的看跌期权和 1 份 3 月行权价格为 110 的看涨期权.
2.1.3 蝶式期权
蝶式期权(butterfly)通常就是一个由相同类型(要么都是看涨,要么都看跌)并具有相同到期时间,且合约间行权价格间距相等的 3 份期货合约组成的三腿价差.蝶式期权多头中,买入外部行权价格的期权合约,卖出内部行权价格的期权合约.构成比例固定不变:都为 1x2x1 .
为何买外卖内算作蝶式的多头? 根据损益图,如果不考虑期权费,买外卖内的损益总不小于零,故必须付出一定金额,所以称为多头.
跨式期权潜在收益或风险都是无限的,而蝶式都是有限的.
2.1.4 鹰式期权
鹰式期权(condor)由 4 份期权组成,2 个内部行权价格和两个外部行权价格.构成比例总是 1x1x1x1 ,尽管两个内部行权价格的差额可以变化,但是 2 个最低行权价格的差额一定要与 2 个最高行权价格的差额相等(why?).与蝶式期权一样,鹰式期权中所有期权的到期时间和类型都相同.买入两个外部行权价格的期权,卖出两个内部行权价格的期权就构成了鹰式期权多头.
上述四个策略对标的市场的变动方向没有偏好,损益图为对称的.
2.1.5 比例价差
在波动率价差中,交易者不能完全不关心标的市场的变动方向.交易者可能认为向一个方向变动的可能性要大于向另一个方向变动的可能性.鉴于这个原因,交易者可能希望构建一个当标的向一个方向而不是另一个方向变动时能最大化收益或最小化损失的价差策略.为了实现这个目标,交易者可以构建一个比例价差(ratio spread)——买入并卖出不同数量的期权,所有期权都是同一类型的,且具有相同的到期时间.和其他波动率头寸一样,比例价差也是典型的 Delta 中性策略.
三.希腊值的含义
序:各种希腊值特性
delta
call 的价值变化:标的相对于行权价的变化

put 的价值变化:标的相对于行权价的变化

delta 随标的的变化

call_delta 随 volatility 的变化

put_delta 随 volatility 的变化

call_delta 随到期时间变化

put_delta 随到期时间变化

call_delta 随着时间推移或者波动率下降的变化

Vanna
Vanna:作为 Delta 对波动率的偏导,或者 Vega 对标的价格的偏导.

theta
theta:期权价格随着标的价格变化,此处取了绝对值(call 与 put 一样,都是负的!跟恒正 GAMMA 比较)

vega

gamma
恒正的 GAMMA:


其中
3.0 随机分析几大基础定理
3.0.1 Radon–Nikodym
3.0.1.1 Radon-Nikodym 定理
测度空间 上定义有两个 -有限测度: 和 .定理表明:如果 (i.e. 关于 绝对连续),则存在一个 -可测函数 s.t. 可测集,
-有限: 是一个测度空间, 是上面一测度. 称为其上一个 -有限测度若 可以写成至多可列个有限测度集合的无交并.
绝对连续:实线上 Borel 子集上的测度 称为关于 Lebesgue 测度 绝对连续若对任何 的可测集 有 .记为 .
等价测度: 和 称为等价测度若 且 .
3.0.1.2 Radon-Nikodym 导数
上述函数 在几乎处处意义下唯一,通常写为 ,通常称为 Radon-Nikodym 导数.
3.0.2 Girsanov 定理
二次变差(quadratic variation):,,计算公式为 .
是 Wiener 概率空间 上的 Wiener 过程.令 是适应于 Wiener 过程生成的自然域流 的可测过程,.
定义 关于 的 Doléans-Dade exponential 如下: 其中 是 的二次变差.若 是严格正的鞅,那么可以定义 上的概率测度 s.t. 有 Radon-Nikodym 导数: 则对每个 , 限制在未扩充的 -域 上和 限制在 是等价的.进一步,若 是 下的局部鞅,那么过程 是 下的在 上的局部鞅.
推论:若 是连续过程, 是 下的布朗运动,那么: 是 下的布朗运动.
3.1 BS公式含义
Moneyness 指的是标的现价相对于行权价格的关系.下面考虑的是远期 F,
进行标准化后为:
下面我们记 ,故
标准的 moneyness 为如下均值: 大小关系为: 每级相差 ,这几项都在单位标准差内,所以把这几项转换成百分数,用标准正态的累计密度函数来评估.对这几个量的解释很精细(subtle),与风险中性测度有关.简单来说,有如下解释:
- 是二项 call option 的未来价值,或者风险中性下期权会在价内行权的可能性,with numéraire cash(风险中性资产)
- 是标准化的货币价值的百分比(概率?)
- 是 Delta,或者风险中性下期权会在价内行权的可能性,with numéraire asset(注意与上面的不同之处,cash 与 asset,债券与标的资产)
These have the same ordering, asis monotonic (since it is a CDF): 因为 是单调的(是一个CDF),他们有大小关系
3.2 计价单位的变换
wiki 概述:
在证券交易的金融市场中,计价单位的变换可以用来对资产定价.例如,若 exp 是 时刻投资在货币市场的 元在 时刻的价格,那么以货币市场定价的所有资产(记作 )在风险测度(记作 )下是鞅: 现在假定 是另一个严格正的交易资产(因此在货币市场的定价下是一个鞅),那么我们能根据 Radon-Nikodym 导数定义一个新的概率测度: 根据贝叶斯定理可以证明 关于新的计价单位 在 下是一个鞅:
知乎总结:
3.2.1 概述
在探讨计价单位变换之前,我们先来粗略的看一下这个公式长什么样子: ,这里 是一个计价单位 (Numéraire).乍一看,这个式子和风险中性定价公式 长得一模一样,只不过是将债券 替换为了另一个东西 ,然后从 测度下的期望变成了在 下的期望.
直观上来理解,就是这个意思,风险中性定价公式其实是以债券为计价单位,从而得出的资产的期望价值,那在某些情况下,我们也可以用其他资产作为计价单位,来对某些资产进行定价,或者是进行计算上的简化,这就是计价单位变换的动机所在.
其实可以看到,计价单位变换的本质,就是从一个 测度转化为另一个 测度,所以这个公式的核心,是找到连接两个测度的 Radon-Nikodym 导数 .
3.2.2 计价单位变换公式的推导
3.2.2.1 RN 导数的存在
这里我们想到的第一个问题是,对于任意一个给出的计价单位 ,存在这样的 RN 导数来定义测度 么?
这里我们需要注意的是,计价单位变换公式中,要求 是任意一个价格严格为正的资产,由先前的知识可以知道,这样的资产以债券计价时 (或者说以货币账户计价时) 是一个鞅,即 在测度 下是一个鞅,这样良好的性质,保证了 RN 导数 的存在性. >定理( 的存在性). 设测度 与 为 上的等价摡率测度,则存在 ,满足 ,且 .
根据这个定理,我们可以找到那个存在的 .可以验证,这样定义的 严格为正,且有 ,满足上述定理: 定义 RN 导数过程 ,可以证明 在原本的测度 下是一个 -鞅: (鞅性质).
3.2.2.2 RN 导数的性质
这里设 是由上一节中的定义 给出.
则在两个测度下期望的运算之间,可以用 RN 导数过程来联系.
设 为 -可测的随机变量.给定 ,则无条件期望之间的关系: 给出简单的证明: 更进一步的,给定 ,则条件期望之间的关系: 这里的其实是贝叶斯定理 (Abstract Bayes' Theorem) 的结论.
定理 (Abstract Bayes). 设 为 上的随机变量, 是 上由 导数 定义的测度.设 为 -代数且 ,则有: .
3.2.2.3 计价单位变换公式
首先我们有风险中性定价公式 ,接着根据定义的 RN 导数 ,可以得到:
对于 conditional on 的公式,根据 Abstract Bayes' Theorem 有:
3.2.2.4 新测度下的股价
根据这个计价单位变换公式,在实际运用时我们首先需要找到合适的计价单位 ,然后将风险中性定价公式中的 替换为 ,接着求一个在测度 下的期望 就可以了.
但要求出 ,我们还需要知道某个随机变量在 下的分布,或者说需要知道它新的 dynamics 是什么,此时还需要 Girsanov 定理来帮助我们找到 下的布朗运动 和 下的布朗运动之间的关系 .
定理 (Grisanov). 设 是 上的布朗运动, 为一个相适应的过程,定义指数鞅过程:,其中 是初值 的相适应的过程, 表示二次变差.则可以定义新的概率测度 .如果在概率测度 下 是一个布朗运动,那么: 在新的概率测度 下也是一个布朗运动.
我们这里以股价 为计价单位来运用 Girsanov 定理,找到 与 之间的关系.根据 的定义,有: 根据 Girsanov 定理,可以得到 ,于是在 下,对数价格 的 SDE 为:
3.2.3 一些栗子
根据风险中性公式, 0 时刻欧式看涨期权的价值应为: 此时,如果我们不想计算一个比较复杂的期望,则可以用 作为计价单位处理第一项,得到: 可以看到,我们其实是将原式转化为求事件 分别在 和 下的概率.
由上一节可以知道,在 下有 ,则:
此时可以很快的得到欧式看矤期权的定价公式:
上述推导也说明了, 表示在 下事件 的概率,即代表了在风险中性世界 中,该看涨期权在到期日被执行的概率,而 表示在 下事件 的概率,即 在以股价为计价单位的世界中,该看涨期权在到期日被执行的概率.
期权模型
随机波动率摘要
一. BS公式(Delta对冲下
BS部分参考知乎专栏:Black-Scholes 模型学习框架
假设股价满足 自融资资产组合(此处为期权价格的一个复制)价值变化为 对 ( 时刻衍生品的价值)做Ito公式,比较项的系数得到Delta 对冲下 的 BS 偏微分方程 (在比较系数过程中会使用到替换为,即为持有的债券份额为自融资总份额减去持有的标的份额,标的的持有量已经使用为) 终值条件(看涨期权)为 解即为 BS 公式 注意到 的漂移项 对期权定价没有影响.
BS公式解法(欧式call)
有考虑边界条件
注意到这是一个 Cauchy-Euler 方程,能通过下述变量代换将其转化为一个扩散方程 The solution of the PDE gives the value of the option at any earlier time, Black-Scholes PDE 变为一个扩散方程: 终值条件 现在变为了初值条件 其中 是 Heaviside 阶梯函数,
使用解给定了初值函数 的扩散方程的标准卷积法,得到 经过处理,得到: 其中 N 为正态分布累计密度函数,
BS model给出了期权价格的函数,作为一个波动率的函数.可以通过这个公式在给定期权价格时计算隐含波动率(implied volatility).但事实是 BS 波动率强烈依赖于欧式期权的到期日和行权价.
波动率微笑是指给定到期日下,隐含波动率与行权价(maturity)的关系.
二. Delta对冲
我们现在考虑用衍生品 和其标的资产 构建一个“无风险组合”,考虑这样的自融资组合 ,即每一单位的空头衍生品,我们用 单位多头的股票对其进行对冲 (Hedging),由于其自融资的特性,根据定义,我们有 ,将股价的 SDE 和上一节中通过伊藤-德布林公式求出的 带入这个式子,我们可以得到:
因为要使资产组合为无风险的,
>一个自融资组合 如果是无风险的,则可以表示为 ,且 .
1式即 Delta 对冲法则,将 带入2式我们再次得到 BlackScholes 偏微分方程: .
三. BS公式(风险中性定价下)
3.1 鞅
定义 在 上的随机过程 , ,称其是关于域流 的鞅,如果满足:
- 是 -适应的 (adapted);
- ;
- 对于 ,有 .
3.2 Radon-Nikodym导数
定义 设 和 为 上的等价测度,若 ,a.s.,且有 , 则称 是 关于 的 Radon-Nikodym 导数,记作: .
,即 ,其中 表示在测度 下的期望.进一步的,可以用条件期望定义R-N导数过程: .用鞅和RN导数过程的定义,可以简单的证明,R-N导数过程 是一个 -鞅.
3.3 资产的现值
用 表示风险资产的价值过程.首先要知道,在 模型的假设下,市场是完备 (Complete) 的,即任意资产 都可以被风险资产 和无风险资产 构成的组合所复制,即对任意一个 ,我们可以把它表示 为一个自融资组合: 可以看到该组合的收益率部分由组合的时间价值 与风险资产的超额收益 构成.我们考虑该资产的折现价值过程
3.4 鞅表示
定理(鞅表示) 设 是 上的布朗运动,而 为 -鞅, 且满足 ,则存在一个 适应的过程 ,使得 .
可以看到,如果 是鞅,那么 可以被表示为一个伊藤积分的形式,即没有 项而仅仅只有 项.再看我们的折现价值过程 ,如果想让它只有 项从而变成一个鞅,我们貌似只需要做变换: ,这样折现价值过程就可以被表示为: 但是,鞅表示定理有个非常非常重要的前提,就是你需要保证 这玩意儿是个伊藤积分,即 需要是一个布朗运动. 我们知道是布朗运动,但是经过这样变换过后的 还是布朗运动么,或者说我们需要如何选择新的测度, 来保证经过变换之后的 仍然是个布朗运动?
3.5 Girsanov定理
定理(Grisanov) 设 是 上的布朗运动 为一个相适应的过程, 定义指数鞅过程, 其中 是初值 的相适应的过程, 表示二次变差.则可以定义新的概率测度 .如果在概率测度 下 是一个布朗运动,那么: 在新的概率测度 下也是一个布朗运动.
这样一来, 我们就找到了新的测度和两个测度之下布朗运动之间的关系.我们看新定义的这个布朗运动:,它的实质是把资产的风险溢价项给消除了.风险溢价是什么?是对承担单位风险的补偿,在新的测度下风险溢价是没有补偿的,所以说在这个世界里,风险是中性的,因此我们把这样定义的新测度 称为风险中性测度,并且用 来表示.
3.6 风险定价公式
现在我们知道了变换公式 ,那么在风险中性测度 下,风险资产 所满足的 SDE 也需要进行相应的变化: 由此可见,在风险中性世界里,风险资产 (例如股票) 的收益率完全等于无风险收益率.
此时任意资产的折现价值过程可以被表示为: .我们知道 在 下是一个鞅,那么由鞅的性贡我们可以知道: . 常利率假设下有: .
假设我们需要对一个欧式看涨期权进行定价,我们知道该期权在到期日 的价值为 ,则有: 其中: 与PDE方法一致.
我们来总结一下 Risk-neutral Pricing 的几个步骤:
- 找到资产的折现价值过程;
- 作测度变换令这个折现价值在新的测度下为鞅;
- 用 Girsanov 定理找到新的变换;
- 利用鞅性质得到风险中性定价公式。
四. SDE WITH ANN
定价模型很重要的一点是能快速地根据现有或者历史价格校准模型.
四. 局部波动率
4.1 Dupire 的工作1994
局部波动率基于如下: 其中, 是 和 的确定性函数, 是固定的参数.在风险中性测度时, 下. 一旦 给定,那么模型也就定了.
Dupire(1994)证明了当给定了关于 (行权价) 和 (到期日) 的期权价格函数 时,局部波动率 是唯一确定的.
令 是如下定义的 的转移密度函数(transitional density function): >转移密度函数: at at ,指在 时刻 条件下在 时刻时 的分布 其中 代表风险中性测度.众所周知 满足倒向的 Fokker-Planck 方程: 其中 一维情形,Fokker-Planck 方程有两个参数,一是拓扑参数 ,另一是扩散
又可证它也满足前向的 Fokker-Planck 方程 其中 且 下面可推导 Dupire 方程,看涨期权的价格满足 其中 是 的风险中性测度.对(4.1)关于 做一次和二次微分,有
已知 满足前向 Fokker-Planck 方程 (4.1)对 微分,有 这里我们假设 与 无关,故我们得到了 Dupire 方程: Dupire 方程最大的优点是把局部波动率函数用期权价格和他们的微分表示出来了 以上结果扩展到了时间依赖的利率,我们只要将 替换为 即可.
局限性
- 可用的观测值是很少的
- 数值微分不可靠,二阶的更甚
确定方程的传统做法是二元样条插值,再对模型进行校准(calibration).
局部波动率作为瞬时方差的条件期望
考虑如下形式的一般的随机波动率模型 其中 远期价格 资产价格过程变为 考虑到 , 欧式看涨期权的 -forward 价格为 上式两端对 求微分 >其中 是 Heaviside函数, 是Dirac函数, , .
对最终的支付应用 Ito 公式并令 有 两侧取条件期望 ,有 上式用到了 鞅的性质.接下来 第二个等式是条件期望公式的一个推广: 或者 因为 故有 进行比较,可知对应的局部波动率模型为 也就是说,局部方差是以最终股票价格 等于行权价 为条件的瞬时方差的风险中性期望.
五. 随机局部波动率模型
5.1 Jex的随机局部波动率模型1999
其中 和 应该是无风险利率减股息收益, 是随机波动率部分, 是波动率的均值回复的平衡点. >Heston随机波动率模型:,.或者()
注意到如果没有 这一项,模型即为Heston模型,而当 为 时模型即为Dupire.
5.2 GAN基于LOCAL_STOCHASTIC_VOLATILITY
This means parameterizing the model pool in a way which is accessible for machine learning techniques and interpreting the inverse problem as a training task of a generative network, whose quality is assessed by anadversary.We pursue this approach in the presentarticle and use as generative models so-called neural stochastic differential equations (SDE),which just means to parameterize the drift and volatility of an Itˆo-SDE by neural networks.
文中指的neural SDE即通过神经网络来对Ito-SDE的漂移项和波动率进行参数化.
这里考虑的某资产的折现后价格过程(discounted price process) :
其中 是某个 中取值的随机过程, 称为杠杆函数(Leverage function)取决于 和资产当前价格.
的选取非常重要,需要很好地校准市场上观测到的隐含波动率.故 需要满足如下条件:
其中 指 Dupire 的local volatility function.注意到(1.1)是 的隐式方程,因为 中需要 .故此时 满足的SDE也成为了一个McKean-Vlasov SDE.
本文采用了 an alternative,fully data-driven 方法,规避了其他计算 Dupire 局部波动率的方法中必须的对波动率曲面插值的做法,即此方法只需离散数据.
令 , 为不同期权的到期日.使用神经网络族 将杠杆函数参数化,参数为 ,i.e.
于是有了neural SDE的生成模型组(generative model class),即使用带参数 的神经网络来参数化漂移项 和波动率项 ,i.e.
本文中,没有漂移项,波动率项如下所示:
依次对每个到期日,参数优化采用如下的校准法则: 其中 是期权的数目, 和 是模型与市场分别的价格.
对固定的 , 是非线性非负凸函数满足 且 对 ,衡量模型和市场价的距离. 某种权重,参数 扮演了对抗(adversarial)的部分,注意到 和 都受 控制.
Rough Volatility
Bergomi's model revisited
Variance swap
A variance swap with maturity is a contract which pays out the realized variance of the logarithmic total returns up to less a strike called the variance swap rate , determined in such a way that the contract has zero value today.
The annualized realized variance of a stock price process for the period with business days is usually defined as The constant denotes the number of trading days per year and is usually fixed to 252 so that . We assume the market is arbitrage-free and prices of traded instruments are represented as conditional expectations with respect to an equivalent pricing measure .
A standard result gives that as , we have
when is a continuous semimartingale.
Approximating the realized variance by the quadratic variation of the log returns works very well for variance swaps, but care should be taken in practise if we price short dated non-linear payoffs on realized variance. Denote by , the price at time of a variance swap with maturity . It is given under by
We define the forward variance curve as Note that, if we assume that the S&PX index follows a diffusion process, with a general stochastic volatility process, , the forward variance is given by It can be seen as the forward instantaneous variance for date , observed at . In particular
The current price of a variance swap, , is given in terms of the forward variances as The models used in practice are based on diffusion dynamics where forward variance curves are given as a functional of a finite-dimensional Markov-process: where the function and the m-dimensional Markov-process satisfy some consistency condition, which essentially ensures that for every fixed maturity , the forward variance is a martingale.
※ Pricing under rough volatility ※
ATM volatility skew
其中 是离到期日的时间, 是log-strike.在传统随机波动率模型中, 对短期时间是常数,对长时间与 成反比.经验上观测到 对某些 与 成比例.
forward variance curve
表示 时刻瞬时方差.则 forward variance curve 为:
Wick exponential
对零均值的 Gaussian R.V. ,其 Wick exponential 为
这里只作记号使用,不涉及其运算.
模型推导
Gatheral et al. (2014) 发现已实现方差(realized variance) 与如下模型一致
其中 是 fBm.This relationship was found to hold for all 21 equity indices in the Oxford-Man database, Bund futures, Crude Oil futures, and Gold futures. Perhaps this feature of the time series of volatility is universal?
考虑 fBm 的 Mandelbrot-Van Ness 表示
其中 , 这样选取是为了保证
将 (2)带入(1)并由 ,可以得到 基于 physical measure 的变化:
注意到 是 -可测,而 与 独立且是 Gaussian with mean zero and variance . 用如下记号: 和 有相同分布,仅仅方差变为 .记 则有 且 结合 Wick expenential 这里,由 1 式可知 依赖 的整个历史,所以 是 non-Markovian.而 2 式表示 the conditional distribution of depends on only through the instantaneous variance forecasts
总结,得到如下模型基于实际概率测度 : 其中,两个布朗运动 和 相关系数为 .
Pricing under Q
期权在 t 时刻的定价基于等价鞅测度 on s.t. 资产价格过程 在 下是一个鞅.
在固定的时间域 中,通过 Girsanov 变换, 使得
另一方面 由 而来,而 是一个布朗运动与 以如下关系相关, 其中 是一对独立的标准布朗运动.对第二项的一个标准的测度变换为 其中 ,for ,是一个合适的适应过程,称为波动率风险的市场价格.所以有 将其重写为 由 4 ,在 下, 特别的, 适应于由 生成的域流(和由 生成的域流一致).把上式重写为 指数中的最后一项明显改变了 的边缘分布.虽然 在 下的条件分布是对数正态的,它在 下不是对数正态.
rBergomi model
考虑最简单的测度变换, assuming for simplicity, resp. as a first approximation, 是关于 的确定性函数.则由 6 我们有 其中 .进一步有 forward variance curve
是如下两项的乘积: 依赖于驱动布朗运动的历史;另一项依赖于风险价格 .
模型 7 是 non-Markovian 因为 .
※on deep calibration of rough sv model※
一.介绍
从隐含波动率按 moneyness 和 maturity 的变化可以观察到存在着著名的 smiles 和 at-the-money(ATM) skews 现象,与 BS 公式相悖.特别的,Bayer, Friz, and Gatheral 经验性地表明 ATM skew 符合如下形式:
其中 为 moneynessand , 为 time to maturity .
根据 Gatheral ,扩散的随机波动率模型不能复现当 time to maturity 趋于零时 volatility skew 的幂指数爆炸现象,反而表现为常数现象.
RSV 可定义为一族连续路径的随机波动率模型,其瞬时波动率由一个 Holder 正则性比布兰运动小的随机过程驱动,通常刻画为 Hurst 系数 H<1/2 的分形布朗运动.
这种范式转变的证据现在是 overwhelming ,一方面在物理测度下,时间序列分析表明对数已实现波动率的 Holder 正则为 0.1 阶;另一方面,在定价测度下经验性观察也表明在零附近由模型能够生成 power-law behaviour 的 volatility skew.
模型的一大难点来自于分形布朗运动的非马尔科夫性.
本文介绍两种方法
- one-step approach : 直接学习从隐含波动率曲面到模型参数的映射,
- two-step approach : 第一步学习从模型参数到期权价格的映射,然后根据实际市场价格校准模型.又分为 point-wise approach 和 grid-wise approach,前者将行权价和到期日作为输入,后者事先设定好这两项.
二.模型校准概述(未使用神经网络)
校准(calibration)意思是调整模型参数以使得模型曲面符合由欧式期权通过BS公式计算出的经验隐含波动率曲面.
假设模型有一个参数集 决定, i.e.,由 .进一步,我们假设期权由参数集 决定.E.g.,对看涨看跌期权我们有 ,分别为到期日和 log-moneyness.有些参数由市场观测得到,如现价、利率等,不在校准过程中.定价映射为 带参数 的模型中带参数 的期权的价格.我们通过 给定了有限子集 以及所有可能的期权参数对应的期权价格.校准是决定模型参数以使模型价格 和市场价格 在给定距离度量下最小,i.e.:
事实上,最常用的 是加权最小二乘: 这里的权重 反映了 对应期权的重要性以及 的可靠性.例如可以选择 bid-ask spread 的倒数.
只要模型参数比 少,此时就是超定(overdetermined)的非线性最小二乘问题,通常采用数值迭代的方法解决,如 Levenberg-Marquardt(LM)算法.
rBergomi :表示为 ,参数 ,例如可以设为 模型基于如下系统 其中 是 Hurst 系数,, 是 Wick exponential, 表示初始forward variance curve, 和 是以 相关的布朗运动.
三.深度校准
3.1 one-step approach
Hernandez A. Model calibration with neural networks[J]. Available at SSRN 2812140, 2016.
直接学习校准过程,即将模型参数视作市场价格(隐含波动率)的函数,i.e. 更具体地,训练神经网络基于标签数据 , 及其对应标签
3.2 two-step approach
首先学习定价映射,将模型参数映射为市场价格(或隐含波动率),然后使用标准校准方法进行校准.我们用 表示 是 的通过神经网络得到的近似.然后第二步我们进行校准
两步方法相较而言最大的好处如下:
- 神经网络只负责期权定价,所以能用人工数据来训练.
- 自然地将误差分为定价误差和模型误差.神经网络表现和模型对市场适应性做出的调整相互独立.
3.2.1 two-step approach: 逐点训练(pointwise)和基于网格(grid-based)训练
In this section, we examine its advantages and present an analysis of the objective function with the goal to enhance learning performance. Within this framework, the pointwise approach has the ability to asses the quality of using Monte Carlo or PDE methods, and indeed it is superior training in terms of robustness.
Pointwise learning
step 1:学习映射 即上述(2)式令 .在标准化期权(vanilla option,)情况下,我们可以直接学习隐含波动率映射 ,而不是期权定价的映射 .用 表示神经网络,最优化问题如下:
Step 2: 解决经典的模型校准步骤:
这里 或者 被替换成 step1 中的近似网络 .
第一步中,关键在于训练数据和网络结构的选择.训练数据在于选择 和 的‘先验’的、有实际意义的分布.
Implicit & grid-based learning
用 记关于到期日和行权价的网格.则
step 1:学习映射 ,输入是 ,输出是 这样的 网格. 取值在 中,其中 strikes maturities 最优化问题变为如下: 其中 . Step 2:
这里期权的参数 是固定了的,不再是学习的一部分.
3.2.2 pointwise versus grid-based
- 最大的不同在于 grid-based 在遇到不在网格上的 T,K 时需要手动插值
- grid-based 方法自然地有 reduction of variance ,
- pointwise 中对使样本符合实际金融数据的操作更简单,改变采样的分布.而 grid-wise 则是通过改变权重或者网格密度.
- grid-based 方法可以看做是一种降低维度的操作,将输入的维度转移到了输出的维度.
四.Pratical implementation
4.1 网络结构与训练
- 隐藏层为 3 层的全连接前馈神经网络,每层 30 个结点
- 输入维度记
- 输出维度为
- 总共有 个参数.
- 激活函数选择 Elu, ,梯度下降选择 Adam.
4.2 校准
使用第二节中讲的 LM 等算法.
五.数值实验
5.1定价近似网络的速度和精确度
※Deep learning volatility: a deep neural network perspective on pricing and calibration in (rough)volatility models※
fBm的 Monte-Carlo 模拟
1.理论基础
Notations: 在单位区间 表示连续函数空间, 表示 -Hölder 连续函数空间, . 和 是 上连续可微和有界连续可微函数空间.
1.1. Hölder spaces and fractional operators
For , the -Hölder space , with the norm is a non-separable Banach space. Following the spirit of Riemann-Liouville fractional operators recalled in Appendix , we introduce the class of Generalised Fractional Operators (GFO). For any we introduce the intervals , and the space , for any
Definition 1.1. For any and , the GFO associated to is defined on as We shall further use the notation , for any . Of particular interest in mathematical finance are the following kernels and operators:
Proposition 1.2. For any and , the operator is continuous.
We develop here an approximation scheme for the following system, generalising the concept of rough volatility in the context of mathematical finance, where the process represents the dynamics of the logarithm of a stock price process: with , and the (strong) solution to the stochastic differential equation
where denotes the state space of the process , usually or The two Brownian motions and , defined on a common filtered probability space , are correlated by the parameter , and the functional is assumed to be smooth on This is enough to ensure that the first stochastic differential equation is well defined. It remains to formulate the precise definition for (Proposition 1.4) to fully specify the system (1.3) and clarify the existence of solutions. Existence and (strong) uniquess of a solution to the second in (1.4) is guaranteed by the following standard assumption :
Assumption 1.3. There exist such that, for all
Proposition 1.4. For any ,the equality holds almost surely for .
Example 1.5. This example is the rough Bergomi model introduced by Bayer, Friz and Gatheral, where with and is the Wick stochastic exponential. This corresponds exactly to with and
1.2 The approximation scheme
We now move on to the core of the project, namely an approximation scheme for the system (1.3). The basic ingredient to construct approximating sequences is a family of iid random variables, which satisfies the following assumption: Assumption 1.6. The family forms an iid sequence of centered random variables with finite moments of all orders and
Following Donsker and Lamperti, we first define, for any , the approximating sequence for the driving Brownian motion as As will be explained later, a similar construction holds to approximate the process : where and Here and satisfy Assumption , with appropriate correlation structure between the pairs that will be made precise later. We shall always use to denote the sequence generating and the one generating . Consequently, we deduce an approximating scheme (up to the interpolating term which decays to zero by Chebyshev's inequality) for as All the approximations above, as well as all the convergence statements below should be understood pathwise, but we omit the dependence in the notations for clarity. The main result here is a convergence statement about the approximating sequence . As usual in weak convergence analysis, convergence is stated in the Skorokhod space of càdlàg processes equipped with the Skorokhod topology. Theorem 1.7. The sequence converges weakly to in . The construction of the proof allows to extend the convergence to the case where is a -dimensional diffusion without additional work. The proof of the theorem requires a certain number of steps: we start with the convergence of the approximation in some Hölder space, which we translate, first into convergence of the stochastic integral in , then, by continuity of the mapping , into convergence of the sequence . All these ingredients are detailed in Section 1.3 below. Once this is achieved, the proof of the theorem itself is relatively straightforward.
1.3. Monte-Carlo.
Theorem 1.7 introduces the theoretical foundations of Monte-Carlo methods (in particular for path-dependent options) for rough volatility models. In this section we give a general and easy-to-understand recipe to implement the class of rough volatility models (1.3). For the numerical recipe to be as general as possible, we shall consider the general time partition on with .
Algorithm 1.8 (Simulation of rough volatility models). (1) Simulate two matrices and with ; (2) simulate M paths of viad and also compute (3) Simulate paths of the fractional driving process using The complexity of this step is in general of order (see Appendix for details). However, this step is easily implemented using discrete convolution with complexity (see Algorithm [B.4 in Appendix for details in the implementation). With the vectors and for , we can write , for , where represents the discrete convolution operator. (4) Use the forward Euler scheme to simulate the log-stock process, for all , as
Remark:
- When , we may skip step (2) and replace by on step (33).
- Step (3) may be replaced by the Hybrid scheme algorithm 11 only when .
Antithetic variates in Algorithm 1.8 are easy to implement as it suffices to consider the uncorrelated random vectors and , for Then and , for , constitute the antithetic variates, which significantly improves the performance of the Algorithm 1.8 by reducing memory requirements, reducing variance and accelerating execution by exploiting symmetry of the antithetic random variables.
1.3.1 Enhancing performance. A standard practice in Monte-Carlo simulation is to match moments of the approximating sequence with the target process. In particular, when the process is Gaussian, matching first and second moments suffices. We only illustrate this approximation for Brownian motion: the left-point approximation may be modified to match moments as where is chosen optimally. Since the kernel is deterministic, there is no confusion with the Stratonovich stochastic integral, and the resulting approximation will always converge to the Itô integral. The first two moments of read The first moment of the approximating sequence 1.8 is always zero, and the second moment reads Equating the theoretical and approximating quantities we obtain for , so that the optimal evaluation point can be computed as In the Riemann-Liouville fractional Brownian motion case, , and the optimal point can be computed in closed form as
1.3.2 Reducing Variance.
As Bayer, Friz and Gatheral, a major drawback in simulating rough volatility models is the very high variance of the estimators, so that a large number of simulations are needed to produce a decent price estimate. Nevertheless, the rDonsker scheme admits a very simple conditional expectation technique which reduces both memory requirements and variance while also admitting antithetic variates. This approach is best suited for calibrating European type options. We consider and the natural filtrations generated by the Brownian motions and In particular the conditional variance process is deterministic. As discussed by Romano and Touzi, and recently adapted to the rBergomi case by McCrickerd and Pakkanen, we can decompose the stock price process as and notice that Thus becomes log-normal and the Black-Scholes closed-form formulae are valid here (European, Barrier options, maximum,...). The advantage of this approach is that the orthogonal Brownian motion is completely unnecessary for the simulation, hence the generation of random numbers is reduced to a half, yielding proportional memory saving. Not only this, but also this simple trick reduces the variance of the Monte-Carlo estimate, hence fewer simulations are needed to obtain the same precision. We present a simple algorithm to implement the rDonsker with conditional expectation and assuming that .
Algorithm 1.9 (Simulation of rough volatility models with Brownian drivers). Consider the equidistant grid . (1) Draw a random matrix with unit variance, and create antithetic variates ; (2) Create a correlated matrix as above; (3) Simulate paths of the fractional driving process using discrete convolution: and store in memory for each (4) use the forward Euler scheme to simulate the log-stock process, for each , as (5) Finally, we have for we may compute any option using the Black-Scholes formula. For instance a Call option with strike would be given by for , where and Thus, the output of the model would be
The algorithm is easily adapted to the case of general diffusions as drivers of the volatility (see Algorithm 1.8 step 2). Algorithm 1.8 is obviously faster than 1.9, especially when using control variates. Nevertheless, with the same number of paths, Algorithm 1.9 remarkably reduces the Monte-Carlo variance, meaning in turn that fewer simulations are needed, making it very competitive for calibration.
2.传统cholesky分解法模拟
If you need to generate correlated Gaussian distributed random variables where is the vector you want to simulate, the vector of means and the given covariance matrix, 1.you first need to simulate a vector of uncorrelated Gaussian random variables, 2.then find a square root of , i.e. a matrix such that . Your target vector is given by A popular choice to calculate is the Cholesky decomposition.
而对于本 rBergomi 模型,
where is a Volterra processt with the scaling property . So far behaves just like . However, the dependence structure is different. Specifically, for where, for and with , where denotes the confluent hypergeometric function. Remark The dependence structure of the Volterra process is markedly different from that of with the MolchanGolosov kernel given by for some constant In particular, for small , correlations drop precipitously as the ratio moves away from 1 .
We also need covariances of the Brownian motion with the Volterra process . With , these are given by and where for future convenience, we have defined the constant, These two formulae may be conveniently combined as Lastly, of course, for . With the number of time steps and the number of simulations, our rBergomi model simulation algorithm may then be summarized as follows.
- Construct the joint covariance matrix for the Volterra process and the Brownian motion and compute its Cholesky decomposition.
- For each time, generate iid normal random vectors and multiply them by the lower triangular matrix obtained by the Cholesky decomposition to get a matrix of paths of and with the correct joint marginals.
- With these paths held in memory, we may evaluate the expectation under of any payoff of interest.
we simulate the process
import numpy as np
import matplotlib.pyplot as plt
import scipy.special as special
def fBm_path_chol(grid_points, M, H, T):
"""
@grid_points: # points in the simulation grid
@H: Hurst Index
@T: time horizon
@M: # paths to simulate
"""
assert 0<H<1.0
## Step1: create partition
X=np.linspace(0, 1, num=grid_points)
# get rid of starting point
X=X[1:grid_points]
## Step 2: compute covariance matrix
Sigma=np.zeros((grid_points-1,grid_points-1))
for j in range(grid_points-1):
for i in range(grid_points-1):
if i==j:
Sigma[i,j]=np.power(X[i],2*H)/2/H
else:
s=np.minimum(X[i],X[j])
t=np.maximum(X[i],X[j])
Sigma[i,j]=np.power(t-s,H-0.5)/(H+0.5)*np.power(s,0.5+H)*special.hyp2f1(0.5-H, 0.5+H, 1.5+H, -s/(t-s))
## Step 3: compute Cholesky decomposition
P=np.linalg.cholesky(Sigma)
## Step 4: draw Gaussian rv
Z=np.random.normal(loc=0.0, scale=1.0, size=[M,grid_points-1])
## Step 5: get V
W=np.zeros((M,grid_points))
for i in range(M):
W[i,1:grid_points]=np.dot(P,Z[i,:])
#Use self-similarity to extend to [0,T]
return W*np.power(T,H)
3.rDonker方法
def fBm_path_rDonsker(grid_points, M, H, T, kernel="optimal"):
"""
@grid_points: # points in the simulation grid
@H: Hurst Index
@T: time horizon
@M: # paths to simulate
@kernel: kernel evaluation point use "optimal" for momen-match or "naive" for left-point
"""
assert 0<H<1.0
## Step1: create partition
dt=1./(grid_points-1)
X=np.linspace(0, 1, num=grid_points)
# get rid of starting point
X=X[1:grid_points]
## Step 2: Draw random variables
dW = np.power(dt, H) *np.random.normal(loc=0, scale=1, size=[M, grid_points-1])
## Step 3: compute the kernel evaluation points
i=np.arange(grid_points-1) + 1
# By default use optimal moment-matching
if kernel=="optimal":
opt_k=np.power((np.power(i,2*H)-np.power(i-1.,2*H))/2.0/H,0.5)
# Alternatively use left-point evaluation
elif kernel=="naive" :
opt_k=np.power(i,H-0.5)
else:
raise NameError("That was not a valid kernel")
## Step 4: Compute the convolution
Y = np.zeros([M, n])
for i in range(int(M)):
Y[i, 1:n] = np.convolve(opt_k, dW[i, :])[0:n - 1]
#Use self-similarity to extend to [0,T]
return Y*np.power(T,H)
※使用GAN对LSV模型的校准※
This means parameterizing the model pool in a way which is accessible for machine learning techniques and interpreting the inverse problem as a training task of a generative network, whose quality is assessed by anadversary.We pursue this approach in the presentarticle and use as generative models so-called neural stochastic differential equations (SDE),which just means to parameterize the drift and volatility of an Itˆo-SDE by neural networks.
1.介绍
文中指的neural SDE即通过神经网络来对Ito-SDE的漂移项和波动率进行参数化.
这里考虑的某资产的折现后价格过程(discounted price process) :
其中 是某个 中取值的随机过程, 称为杠杆函数(Leverage function)取决于 和资产当前价格.
的选取非常重要,需要很好地校准市场上观测到的隐含波动率.故 需要满足如下条件:
其中 指 Dupire 的local volatility function.注意到(1.1)是 的隐式方程,因为 中需要 .故此时 满足的SDE也成为了一个McKean-Vlasov SDE.
本文采用了 fully data-driven 方法,规避了其他计算 Dupire 局部波动率的方法中必须的对波动率曲面插值的做法,即此方法只需离散数据.
令 , 为不同期权的到期日.使用神经网络族 将杠杆函数参数化,参数为 ,i.e.
于是有了neural SDE的生成模型组(generative model class),即使用带参数 的神经网络来参数化漂移项 和波动率项 ,i.e.
本文中,没有漂移项,波动率项如下所示:
依次对每个到期日,参数优化采用如下的校准法则: 其中 是期权的数目, 和 是模型与市场分别的价格.
对固定的 , 是非线性非负凸函数满足 且 对 ,衡量模型和市场价的距离. 某种权重,参数 扮演了对抗(adversarial)的部分,注意到 和 都受 控制.本文中 采用的是 Cont R, Ben Hamida S. Recovering volatility from option prices by evolutionary optimization[J]. 2004.中的 vega-type.
2.VARIANCE REDUCTION FOR PRICING AND CALIBRATION VIA HEDGING AND DEEP HEDGING
介绍在蒙特卡洛定价和校准中利用对冲投资组合作为控制变量的方差缩减技术.在 LSV 校准中非常重要.
考虑有限时域 ,已折现的市场中有 个交易中的金融产品 ,它是在某个概率空间 上在 中取值的随机变量. 是风险中性测度, 假设是右连续的.特别的,假设 是有右连左极路径的 维平方可积鞅.
令 是 可测的随机变量,表示表示某个欧式期权在到期日 的支付.那么通常的对这个期权价格的 Monte Carlo 估计是: 其中, 是以分布 , i.i.d 的.可以简单改造这个估计,加上关于 的随机积分.考虑一个策略 和某个常数 .用 记关于 的随机积分,考虑如下估计: 其中, 是以分布 i.i.d 的.则对于任意的 和 ,这个估计仍是期权价格的无偏估计,因为随机积分的期望消失了.记 则 的方差为: 在以下取法下达到最小 此时 特别地,在沿路径完美对冲的情形下, a.s.,有 和 ,此时 因此,找到一个好的近似对冲投资组合使得 大是很重要的.
2.1 Black&Scholes Delta Hedge
In many cases, of local stochastic volatility models as of form (1.1) and options depending only on the terminal value of the price process, a Delta hedge of the BlackâĂŞScholes model works well.
令 , 是 BS 模型下 时刻的价格.对冲策略为:
2.2 Hedging Strategies as Neural Networks-Deep Hedging
在对冲产品数很多等情况下时,可以将对冲策略用神经网络参数化.令期权的支付是对冲产品最终价值的函数,i.e.,.在马尔科夫模型中,可以用函数表示对冲策略: 对应这样一个神经网络:. 是网络参数.根据Buehler H, Gonon L, Teichmann J, et al. Deep hedging[J]. Quantitative Finance, 2019, 19(8): 1271-1291. 给定 的最优对冲可以如下计算 是凸的损失函数.
为了解决这个最优问题,采用随机梯度下降,随机目标函数 为: 记最优的参数 和最优对冲策略 .
假定激活函数和凸损失函数是光滑的.下面要证明 的梯度是: i.e.,我们可以把梯度移到随机积分中.为此,我们要使用下述定理.
定理 2.1:,令 是
Theorem 2.1. For ling, let be a solution of a stochastic differential equation as described in Theorem with drivers , functionally Lipschitz operators , and a process , which is here for all simply for some constant vector , i.e. Let be a map, such that the bounded càglàd process converges to , then holds true.
推论 2.2:,令 为对冲产品过程 的离散,使得定理 2.1 中的条件都满足.对应的对冲策略 由神经网络 给出,其中网络的激活函数有界 ,且导数有界.那么
(i) 随机积分在 点关于 导数 满足
(ii) 若当 时, ucp 收敛到 ,则离散积分的方向导数,i.e. 随着离散刻度 收敛到
ucp means uniform convergence on compacts in probability,i.e.,if for all . The notation is sometimes used, and is said to converge ucp to .
3. LSV的校准
考虑定义在某个概率空间 上的(1.1)LSV模型, 是风险中性测度.假定随机过程 固定.所以实际中我们可以先令 来近似校准其他参数并固定他们.
主要目标是确定符合市场数据的杠杆函数 ,根据通用近似定理(universal approximation properties),对其参数化.令 为欧式看涨期权的到期日.将 用如下神经网络近似 其中 , .方便起见,通常省略 .当我们写 时, 表示 时刻前所有的参数 .
训练过程中,我们需要计算 LSV 过程关于 的导数.以下结果可以看做 对应的链式法则.从附录 A 推导而来.
定理 3.1:令 为(3.1)形式,神经网络 有界且 ,导数有界且 Lipschitz 连续.则关于 在 点处的导数满足: 初值为 0.这个可以通过常数变易来解,i.e. 其中 表示随机指数(stochastic exponential).
Remark
(i) 只看存在唯一性的话, 为 (3.1) 形式,那么神经网络 有界以及 Lipschitz 足够了,.
(ii) 公式 (3.3) 可以用来倒向传播.
定理 3.1 保证了导数过程的存在唯一性.这也保证了基于梯度搜索的学习算法的建立.
下面叙述如何具体优化.为了记号方便,省略权重 和损失函数 对应的参数 .对每个到期日 ,我们假定有 个期权,行权价为 .对第 个到期日,校准函数的形式为 回忆 指的是对应到期日 和行权价 的模型期权价格. 是某个非负非线性凸的损失函数满足 对 . 是权重.
我们通过迭代地计算最优化问题(3.5),从 和 出发,计算 ,然后解决对应 的(3.5).为了简便记号,去掉 ,考虑一般的到期日 ,(3.5)变为 模型价格由下式给出 我们有 ,其中 那么校准问题变为寻找最小的 因为 是非线性函数,不是 B.1 中的期望形式,标准的随机梯度下降方法不能直接用.我们通过第二节中讲的对冲控制变量 (hedge control variates) 解决这个问题.
3.1 极小化校准方程
考虑标准的对(3.8) 的 Monte-Carlo 模拟: 对 i.i.d 的样本 .Monte-Carlo 误差以 递减.模拟次数 必须很大 .因为由于 非线性,随机梯度下降不能直接使用,所以看起来要计算整个函数 的梯度来最小化(3.9).但 ,这一做法计算成本太大且不稳定,因为要计算 项的和的导数.
一个方便的做法是应用对冲控制变量来降低方差,可以将 Monte-Carlo 的样本数 降为大约 .
假定我们有 个对冲产品(包含价格过程 ),用 表示,为 下的平方可积鞅,在 下取值.对 ,策略 使得 , 为常数,定义 则校准函数(3.8)和(3.9)可以通过替换 为 来定义,变为最小化 对此,我们应用如下梯度下降的变种:从初始猜测 出发,迭代计算 对某个学习率 ,i.i.d 样本 .其中 是基于梯度待确定的量,样本在每次迭代中可以一样,可以另取.本文中另取.
最简单情形下,可以令
注意到(3.10)中随机积分项的导数计算通常是昂贵的.我们进行下述改造.令 定义 : 然后令 注意到基于倒向传播,这一项计算起来是很简单的.Moreover, leaving the stochastic integral away in the inner derivative is justified by its vanishing expectation. During the forward pass, the stochastic integral terms are included in the computation; however the contribution to the gradient (during the backward pass) is partly neglected, which can e.g. be implemented via the tensorflow stop_gradient function.
关于对冲策略的选择,我们可以按照 2.2 节中的方法将其用神经网络参数化,并通过下式计算最优的权重 : 对 i.i.d 样本 和损失函数 .此处 这意味着迭代两个优化步骤,i.e.,优化(3.11)中的 (固定 ) 和(3.14)中的 (固定 ).
4. 数值实验流程
实际使用的 SABR-LSV 模型如下 参数为 ,初值有 . 和 是两个相关的布朗运动.
Remark:一般使用的是关于 的对数价格 .故模型也可写为: 注意到 是一个几何布朗运动,也就是说它有表达式:
生成样本
在已有文献中,有推荐的局部波动率函数族 如下: 其中 且参数满足如下约束: 令 , 如下定义: 文中作者修改为: 其中 注意到 与 有关.所以在做 Monte Carlo 模拟时,我们将 替换为 , 是 Monte Carlo 模拟的时间间隔. What is left to be specified are the parameters 模型变为: 上式是用来生成人工市场价格样本的.
所以我们实际的做法是随机对 中的 采样再根据 (1) 计算出价格,然后对 SABR-LSV 模型进行校准,i.e. 寻找使模型符合上述价格的参数 , 以及 .
到期日 ,每个到期日 对应行权价为 .用 Monte-Carlo 模拟以 间隔计算价格.
具体如下:
- 在 下对 以给定分布进行模拟.
- 对每个 ,根据(1)式计算 and strikes for and 对应的欧式期权的价格.每个 分别使用不同的 条布朗运动轨道.
- 保存这些价格数据
准备做的工作(弃案)
寻找最适合市场波动率曲面的“复合”模型,即假设市场波动率曲面实际是由一些波动率模型的凸组合决定的.
回忆:波动率曲面即隐含波动率以 :time to maturity 和 :log-moneyness 为自变量构成的曲面. 例如,我们可以假设当前 ,其中 .
大致做法
记号:分别以 、 记 Heston 和 rough 模型的参数集,以 、 记该两者通过神经网络训练得到的从模型参数到市场价格(隐含波动率)的映射,、 为前述凸组合系数.
我们这里考虑直接通过神经网络来学习市场波动率曲面 到凸组合系数 的映射.
我们对两个模型的参数以及 分别均匀采样,然后根据两者的模型分别模拟出不同凸组合下两者的复合波动率曲面,但要注意的是两者采用同一个参数 (即两个标准布朗运动的相关系数)并且一个凸组合下两模型使用同一条 Monte-Carlo 轨道.这时,忽略掉模型的参数,我们有了带有 标签的许多波动率曲面样本,我们利用前述 grid-based 的方法通过神经网络学习从波动率曲面到凸组合系数的映射.
知道了 后,如何校准出两个模型分别的参数?
LSV-ROUGH 模型的校准
模型:
杠杆函数:
主要共有两个神经网络,一个负责 Rough 的部分,一个负责 LSV 的部分.
一方面,Rough 部分的网络对应的即 Bayer 提出的 two-step 校准方法,即如下模型 ((1.1)中 时): 对应的从模型参数到模型对应价格的映射的网络.只需用人工模拟数据训练一次后,网络就固定住了,在校准等步骤中是不会再变动的.
回忆:用 记关于到期日和行权价的网格.则 step 1:学习映射 ,输入是 ,输出是 这样的 网格. 取值在 中,其中 strikes maturities 最优化问题变为如下: 其中 . Step 2:
另一方面,我们将 这个函数用网络近似,这个网络中的参数是随着校准不断变动的.具体地,令 为欧式看涨期权的到期日.将 用如下神经网络近似 其中 , .
为了记号方便,省略权重 和损失函数 对应的参数 .对每个到期日 ,我们假定有 个期权,行权价为 .对第 个到期日,校准函数的形式为 回忆 指的是对应到期日 和行权价 的模型期权价格. 是某个非负非线性凸的损失函数满足 对 . 是权重.
我们通过迭代地计算最优化问题(1.3),从 和 出发,计算 ,然后解决对应 的(1.3).为了简便记号,去掉 ,考虑一般的到期日 ,(1.3)变为 模型价格由下式给出 我们有 ,其中 那么校准问题变为寻找最小的 我们通过第二节中讲的对冲控制变量 (hedge control variates) 解决这个问题.
考虑标准的对(3.8) 的 Monte-Carlo 模拟: 对 i.i.d 的样本 .Monte-Carlo 误差以 递减.模拟次数 必须很大 .因为由于 非线性,随机梯度下降不能直接使用,所以看起来要计算整个函数 的梯度来最小化(3.9).但 ,这一做法计算成本太大且不稳定,因为要计算 项的和的导数.
一个方便的做法是应用对冲控制变量来降低方差,可以将 Monte-Carlo 的样本数 降为大约 .
假定我们有 个对冲产品(包含价格过程 ),用 表示,为 下的平方可积鞅,在 下取值.对 ,策略 使得 , 为常数,定义 则校准函数(3.8)和(3.9)可以通过替换 为 来定义,变为最小化
算法1:模型的校准步骤
-
# 初始化网络参数
-
-
# 定义初始模拟轨道数和初始步骤值
-
-
# 定义时间离散间隔和误差容忍度
-
-
:
-
-
# 计算此次切片的初始正规化权重
-
-
-
-
-
-
-
-
-
-
-
-
-
-
算法2:超参的更新
附录
证明:
首先定理 A.2 暗示了 的解存在唯一性.这里驱动过程是一维的 .事实上,若 有界,对 左极右连,对 Lipschitz 连续以一个与 无关的 Lipschitz 常数. 为 functionally Lipschitz,得到结论.这些条件由 的形式和 的条件保证.
为了证明导数过程的形式,我们对如下系统应用定理 A.3: 和 以及 在定理 A.3 中,. 为 ucp 收敛到 事实上,, 等度连续.因此,点点收敛暗示对 的一致连续.This together with being piecewise constant in yields: whence ucp convergence of the first term in (3.4). The convergence of term two is clear. The one of term three follows again from the fact that the family is equicontinuous, which is again a consequence of the form of the neural networks.
By the assumptions on the derivatives, is functionally Lipschitz. Hence Theorem A.2 yields the existence of a unique solution to (3.2) and Theorem A.3 implies convergence.
Proof. Consider the extended system and where we obtain existence, uniqueness and stability for the second equation by Theorem A.3, and from where we obtain ucp convergence of the integrand of the first equation: since stochastic integration is continuous with respect to the ucp topology we obtain the result.
文献
- 首次在波动率校准中运用神经网络 Hernandez A. Model calibration with neural networks[J]. Available at SSRN 2812140, 2016.
- rough波动率模型的神经网络校准 Bayer C, Horvath B, Muguruza A, et al. On deep calibration of (rough) stochastic volatility models[J]. arXiv preprint arXiv:1908.08806, 2019.和 Horvath B, Muguruza A, Tomas M. Deep learning volatility: a deep neural network perspective on pricing and calibration in (rough) volatility models[J]. Quantitative Finance, 2021, 21(1): 11-27. Github 代码
- LSV模型GAN校准 Cuchiero C, Khosrawi W, Teichmann J. A generative adversarial network approach to calibration of local stochastic volatility models[J]. Risks, 2020, 8(4): 101. Github 代码
- 损失函数中不同期权权重取法 Cont R, Ben Hamida S. Recovering volatility from option prices by evolutionary optimization[J]. 2004.
- fBm的MC模拟 Horvath B, Jacquier A J, Muguruza A. Functional central limit theorems for rough volatility[J]. Available at SSRN 3078743, 2017. Github 代码
- rBergomi提出 Bayer C, Friz P, Gatheral J. Pricing under rough volatility[J]. Quantitative Finance, 2016, 16(6): 887-904.
wing model
一. 模型建立
波动率模型 Wing Model(知乎, maomao.run)

Wing Model是期权交易中常见的一种对波动率进行建模的方法. 它通过调整参数, 将市场中一个系列的期权的隐含波动率拟合到一个曲线上. Wing Model 把隐含波动率曲线分为 6 个区域, 以 ATM Forward(期权对应标的远期价)为中心, 左边区域 1, 2, 3 构成 Put Wing, 右边区域 4, 5, 6 构成 Call Wing. 其中, 区域 1, 6 为常数波动率部分, 区域 3, 4 为抛物线部分, 区域 2, 5 则为过渡部分(其实也是抛物线). x 轴为期权的行权价(或者对数化行权价), y 轴为期权波动率.
| 名称 | 参数 | 描述 |
|---|---|---|
| atm forward | atm | 期权对应合成期货价 |
| volatility reference | vr(vc) | 中心点参考波动率 |
| slope reference | sr(sc) | 中心点参考斜率,也是 Put Wing3 和 Call Wing4 抛物线共同的一次项系数 |
| put curvature | pc | Put Wing3 抛物线的二次项系数 |
| call curvature | cc | Call Wing4 抛物线的二次项系数 |
| down cutoff | dc | Put Wing 2和3 交界点 x 值, dc<0 |
| up cutoff | uc | Call Wing 4和5 交界点 x 值, uc>0 |
| down smoothing range | dsm | Put Wing 用来计算 1 和 2 交界点 x 值的参数 |
| up smoothing range | usm | Call Wing 用来计算 5 和 6 交界点 x 值的参数 |
| skew swimmingness rate | ssr | 斜率游离系数, 取值范围 , 在计算合成期货价格时, 该参数用来调节 atm 和 ref 的比例 |
| volatility change rate | vcr | 波动率变化系数 |
| slope change rate | scr | 斜率变化系数 |
| reference price | ref | ssr<100 时, 需要定义 ref 用来描述 vcr 和 scr 对中心点波动率和中心点斜率的影响 |
其中,atm已知,dc、dsm、uc、usm一般为经验值,vcr、scr、ssr一般使用默认值0、0、100,vc、sc、pc、cc待拟合.
代码中取的几个默认参数是:dc=-0.2, uc=0.2, dsm=0.5, usm=0.5.
- 以50etf为例,当前2022-12-13 13:12近月合成期货为2.682,那么中间两个区域边界分别为;
- 以300etf为例,当前2022-12-13 13:12近月合成期货为4.012,那么中间两个区域边界分别为;
如果 dc=-0.15, uc=0.15 , 50中间区域边界为 [2.31, 3.12], 300中间区域边界为 [3.45, 4.66].
除了上述参数, 还需要一些中间参数以便于表示最终函数. 其中,
合成期货价格: 是中心点波动率: 为中心点斜率: 在常规使用的时候,下列参数一般取默认值: 将默认值带入中间参数公式,可得: 因此我们平常使用,只需要 这9个参数就 够了. 依照参数的定义, 我们可以定义出区域之间的 5 个分隔点的 x 坐标, 从左到右依次为: 把这个五个点对数化,即: 其中 为原 坐标, 为合成期货价格. 对数化后我们重新定义出 5 个分隔点新的 x 坐标, 从左到右依次为:
函数求解
区域 3 和区域 4 的抛物线函数是由参数确定的: 根据各区域连接处导数一致列方程即可求出其他区域的表达式:
二. 根据无套利进行推导
Wing-Model Volatility Skew Manager, 子非鱼根据Jim Gatheral 在Arbitrage-free SVI volatility surfaces 提到的无套利观点和算法对 wing-model 进行公式推导分析,在拟合的基础上进一步根据按定义域分 6 块给出 6 个约束条件判断拟合曲线无蝶式套利.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2020/5/22 11:12 PM
# @Author : 稻草人
# @contact : dybeta2021@163.com
# @File : wing_model.py
# @Desc : orc wing model
from functools import partial
from numpy import ndarray, array, arange, zeros, ones, argmin, minimum, maximum, clip
from numpy.linalg import norm
from numpy.random import normal
from scipy.interpolate import interp1d
from scipy.optimize import minimize
class WingModel(object):
def skew(moneyness: ndarray, vc: float, sc: float, pc: float, cc: float, dc: float, uc: float, dsm: float,
usm: float) -> ndarray:
"""
:param moneyness: converted strike, moneyness
:param vc:
:param sc:
:param pc:
:param cc:
:param dc:
:param uc:
:param dsm:
:param usm:
:return:
"""
assert -1 < dc < 0
assert dsm > 0
assert 1 > uc > 0
assert usm > 0
assert 1e-6 < vc < 10 # 数值优化过程稳定
assert -1e6 < sc < 1e6
assert dc * (1 + dsm) <= dc <= 0 <= uc <= uc * (1 + usm)
# volatility at this converted strike, vol(x) is then calculated as follows:
vol_list = []
for x in moneyness:
# volatility at this converted strike, vol(x) is then calculated as follows:
if x < dc * (1 + dsm):
vol = vc + dc * (2 + dsm) * (sc / 2) + (1 + dsm) * pc * pow(dc, 2)
elif dc * (1 + dsm) < x <= dc:
vol = vc - (1 + 1 / dsm) * pc * pow(dc, 2) - sc * dc / (2 * dsm) + (1 + 1 / dsm) * (
2 * pc * dc + sc) * x - (pc / dsm + sc / (2 * dc * dsm)) * pow(x, 2)
elif dc < x <= 0:
vol = vc + sc * x + pc * pow(x, 2)
elif 0 < x <= uc:
vol = vc + sc * x + cc * pow(x, 2)
elif uc < x <= uc * (1 + usm):
vol = vc - (1 + 1 / usm) * cc * pow(uc, 2) - sc * uc / (2 * usm) + (1 + 1 / usm) * (
2 * cc * uc + sc) * x - (cc / usm + sc / (2 * uc * usm)) * pow(x, 2)
elif uc * (1 + usm) < x:
vol = vc + uc * (2 + usm) * (sc / 2) + (1 + usm) * cc * pow(uc, 2)
else:
raise ValueError("x value error!")
vol_list.append(vol)
return array(vol_list)
def loss_skew(cls, params: [float, float, float], x: ndarray, iv: ndarray, vega: ndarray, vc: float, dc: float,
uc: float, dsm: float, usm: float):
"""
:param params: sc, pc, cc
:param x:
:param iv:
:param vega:
:param vc:
:param dc:
:param uc:
:param dsm:
:param usm:
:return:
"""
sc, pc, cc = params
vega = vega / vega.max()
value = cls.skew(x, vc, sc, pc, cc, dc, uc, dsm, usm)
return norm((value - iv) * vega, ord=2, keepdims=False)
def calibrate_skew(cls, x: ndarray, iv: ndarray, vega: ndarray, dc: float = -0.2, uc: float = 0.2, dsm: float = 0.5,
usm: float = 0.5, is_bound_limit: bool = False,
epsilon: float = 1e-16, inter: str = "cubic"):
"""
:param x: moneyness
:param iv:
:param vega:
:param dc:
:param uc:
:param dsm:
:param usm:
:param is_bound_limit:
:param epsilon:
:param inter: cubic inter
:return:
"""
vc = interp1d(x, iv, kind=inter, fill_value="extrapolate")([0])[0]
# init guess for sc, pc, cc
if is_bound_limit:
bounds = [(-1e3, 1e3), (-1e3, 1e3), (-1e3, 1e3)]
else:
bounds = [(None, None), (None, None), (None, None)]
initial_guess = normal(size=3)
args = (x, iv, vega, vc, dc, uc, dsm, usm)
residual = minimize(cls.loss_skew, initial_guess, args=args, bounds=bounds, tol=epsilon, method="SLSQP")
assert residual.success
return residual.x, residual.fun
def sc(sr: float, scr: float, ssr: float, ref: float, atm: ndarray or float) -> ndarray or float:
return sr - scr * ssr * ((atm - ref) / ref)
def loss_scr(cls, x: float, sr: float, ssr: float, ref: float, atm: ndarray, sc: ndarray) -> float:
return norm(sc - cls.sc(sr, x, ssr, ref, atm), ord=2, keepdims=False)
def fit_scr(cls, sr: float, ssr: float, ref: float, atm: ndarray, sc: ndarray,
epsilon: float = 1e-16) -> [float, float]:
init_value = array([0.01])
residual = minimize(cls.loss_scr, init_value, args=(sr, ssr, ref, atm, sc), tol=epsilon, method="SLSQP")
assert residual.success
return residual.x, residual.fun
def vc(vr: float, vcr: float, ssr: float, ref: float, atm: ndarray or float) -> ndarray or float:
return vr - vcr * ssr * ((atm - ref) / ref)
def loss_vc(cls, x: float, vr: float, ssr: float, ref: float, atm: ndarray, vc: ndarray) -> float:
return norm(vc - cls.vc(vr, x, ssr, ref, atm), ord=2, keepdims=False)
def fit_vcr(cls, vr: float, ssr: float, ref: float, atm: ndarray, vc: ndarray,
epsilon: float = 1e-16) -> [float, float]:
init_value = array([0.01])
residual = minimize(cls.loss_vc, init_value, args=(vr, ssr, ref, atm, vc), tol=epsilon, method="SLSQP")
assert residual.success
return residual.x, residual.fun
def wing(cls, x: ndarray, ref: float, atm: float, vr: float, vcr: float, sr: float, scr: float, ssr: float,
pc: float, cc: float, dc: float, uc: float, dsm: float, usm: float) -> ndarray:
"""
wing model
:param x:
:param ref:
:param atm:
:param vr:
:param vcr:
:param sr:
:param scr:
:param ssr:
:param pc:
:param cc:
:param dc:
:param uc:
:param dsm:
:param usm:
:return:
"""
vc = cls.vc(vr, vcr, ssr, ref, atm)
sc = cls.sc(sr, scr, ssr, ref, atm)
return cls.skew(x, vc, sc, pc, cc, dc, uc, dsm, usm)
class ArbitrageFreeWingModel(WingModel):
def calibrate(cls, x: ndarray, iv: ndarray, vega: ndarray, dc: float = -0.2, uc: float = 0.2, dsm: float = 0.5,
usm: float = 0.5, is_bound_limit: bool = False, epsilon: float = 1e-16, inter: str = "cubic",
level: float = 0, method: str = "SLSQP", epochs: int = None, show_error: bool = False,
use_constraints: bool = False) -> ([float, float, float], float):
"""
:param x:
:param iv:
:param vega:
:param dc:
:param uc:
:param dsm:
:param usm:
:param is_bound_limit:
:param epsilon:
:param inter:
:param level:
:param method:
:param epochs:
:param show_error:
:param use_constraints:
:return:
"""
vega = clip(vega, 1e-6, 1e6)
iv = clip(iv, 1e-6, 10)
# init guess for sc, pc, cc
if is_bound_limit:
bounds = [(-1e3, 1e3), (-1e3, 1e3), (-1e3, 1e3)]
else:
bounds = [(None, None), (None, None), (None, None)]
vc = interp1d(x, iv, kind=inter, fill_value="extrapolate")([0])[0]
constraints = dict(type='ineq', fun=partial(cls.constraints, args=(x, vc, dc, uc, dsm, usm), level=level))
args = (x, iv, vega, vc, dc, uc, dsm, usm)
if epochs is None:
if use_constraints:
residual = minimize(cls.loss_skew, normal(size=3), args=args, bounds=bounds, constraints=constraints,
tol=epsilon, method=method)
else:
residual = minimize(cls.loss_skew, normal(size=3), args=args, bounds=bounds, tol=epsilon, method=method)
if residual.success:
sc, pc, cc = residual.x
arbitrage_free = cls.check_butterfly_arbitrage(sc, pc, cc, dc, dsm, uc, usm, x, vc)
return residual.x, residual.fun, arbitrage_free
else:
epochs = 10
if show_error:
print("calibrate wing-model wrong, use epochs = 10 to find params! params: {}".format(residual.x))
if epochs is not None:
params = zeros([epochs, 3])
loss = ones([epochs, 1])
for i in range(epochs):
if use_constraints:
residual = minimize(cls.loss_skew, normal(size=3), args=args, bounds=bounds,
constraints=constraints,
tol=epsilon, method="SLSQP")
else:
residual = minimize(cls.loss_skew, normal(size=3), args=args, bounds=bounds, tol=epsilon,
method="SLSQP")
if not residual.success and show_error:
print("calibrate wing-model wrong, wrong @ {} /10! params: {}".format(i, residual.x))
params[i] = residual.x
loss[i] = residual.fun
min_idx = argmin(loss)
sc, pc, cc = params[min_idx]
loss = loss[min_idx][0]
arbitrage_free = cls.check_butterfly_arbitrage(sc, pc, cc, dc, dsm, uc, usm, x, vc)
return (sc, pc, cc), loss, arbitrage_free
def constraints(cls, x: [float, float, float], args: [ndarray, float, float, float, float, float],
level: float = 0) -> float:
"""蝶式价差无套利约束
:param x: guess values, sc, pc, cc
:param args:
:param level:
:return:
"""
sc, pc, cc = x
moneyness, vc, dc, uc, dsm, usm = args
if level == 0:
pass
elif level == 1:
moneyness = arange(-1, 1.01, 0.01)
else:
moneyness = arange(-1, 1.001, 0.001)
return cls.check_butterfly_arbitrage(sc, pc, cc, dc, dsm, uc, usm, moneyness, vc)
"""蝶式价差无套利约束条件
"""
def left_parabolic(sc: float, pc: float, x: float, vc: float) -> float:
"""
:param sc:
:param pc:
:param x:
:param vc:
:return:
"""
return pc - 0.25 * (sc + 2 * pc * x) ** 2 * (0.25 + 1 / (vc + sc * x + pc * x * x)) + (
1 - 0.5 * x * (sc + 2 * pc * x) / (vc + sc * x + pc * x * x)) ** 2
def right_parabolic(sc: float, cc: float, x: float, vc: float) -> float:
"""
:param sc:
:param cc:
:param x:
:param vc:
:return:
"""
return cc - 0.25 * (sc + 2 * cc * x) ** 2 * (0.25 + 1 / (vc + sc * x + cc * x * x)) + (
1 - 0.5 * x * (sc + 2 * cc * x) / (vc + sc * x + cc * x * x)) ** 2
def left_smoothing_range(sc: float, pc: float, dc: float, dsm: float, x: float, vc: float) -> float:
a = - pc / dsm - 0.5 * sc / (dc * dsm)
b1 = -0.25 * ((1 + 1 / dsm) * (2 * dc * pc + sc) - 2 * (pc / dsm + 0.5 * sc / (dc * dsm)) * x) ** 2
b2 = -dc ** 2 * (1 + 1 / dsm) * pc - 0.5 * dc * sc / dsm + vc + (1 + 1 / dsm) * (2 * dc * pc + sc) * x - (
pc / dsm + 0.5 * sc / (dc * dsm)) * x ** 2
b2 = (0.25 + 1 / b2)
b = b1 * b2
c1 = x * ((1 + 1 / dsm) * (2 * dc * pc + sc) - 2 * (pc / dsm + 0.5 * sc / (dc * dsm)) * x)
c2 = 2 * (-dc ** 2 * (1 + 1 / dsm) * pc - 0.5 * dc * sc / dsm + vc + (1 + 1 / dsm) * (2 * dc * pc + sc) * x - (
pc / dsm + 0.5 * sc / (dc * dsm)) * x ** 2)
c = (1 - c1 / c2) ** 2
return a + b + c
def right_smoothing_range(sc: float, cc: float, uc: float, usm: float, x: float, vc: float) -> float:
a = - cc / usm - 0.5 * sc / (uc * usm)
b1 = -0.25 * ((1 + 1 / usm) * (2 * uc * cc + sc) - 2 * (cc / usm + 0.5 * sc / (uc * usm)) * x) ** 2
b2 = -uc ** 2 * (1 + 1 / usm) * cc - 0.5 * uc * sc / usm + vc + (1 + 1 / usm) * (2 * uc * cc + sc) * x - (
cc / usm + 0.5 * sc / (uc * usm)) * x ** 2
b2 = (0.25 + 1 / b2)
b = b1 * b2
c1 = x * ((1 + 1 / usm) * (2 * uc * cc + sc) - 2 * (cc / usm + 0.5 * sc / (uc * usm)) * x)
c2 = 2 * (-uc ** 2 * (1 + 1 / usm) * cc - 0.5 * uc * sc / usm + vc + (1 + 1 / usm) * (2 * uc * cc + sc) * x - (
cc / usm + 0.5 * sc / (uc * usm)) * x ** 2)
c = (1 - c1 / c2) ** 2
return a + b + c
def left_constant_level() -> float:
return 1
def right_constant_level() -> float:
return 1
def _check_butterfly_arbitrage(cls, sc: float, pc: float, cc: float, dc: float, dsm: float, uc: float, usm: float,
x: float, vc: float) -> float:
"""检查是否存在蝶式价差套利机会,确保拟合time-slice iv-curve 是无套利(无蝶式价差静态套利)曲线
:param sc:
:param pc:
:param cc:
:param dc:
:param dsm:
:param uc:
:param usm:
:param x:
:param vc:
:return:
"""
# if x < dc * (1 + dsm):
# return cls.left_constant_level()
# elif dc * (1 + dsm) < x <= dc:
# return cls.left_smoothing_range(sc, pc, dc, dsm, x, vc)
# elif dc < x <= 0:
# return cls.left_parabolic(sc, pc, x, vc)
# elif 0 < x <= uc:
# return cls.right_parabolic(sc, cc, x, vc)
# elif uc < x <= uc * (1 + usm):
# return cls.right_smoothing_range(sc, cc, uc, usm, x, vc)
# elif uc * (1 + usm) < x:
# return cls.right_constant_level()
# else:
# raise ValueError("x value error!")
if dc < x <= 0:
return cls.left_parabolic(sc, pc, x, vc)
elif 0 < x <= uc:
return cls.right_parabolic(sc, cc, x, vc)
else:
return 0
def check_butterfly_arbitrage(cls, sc: float, pc: float, cc: float, dc: float, dsm: float, uc: float, usm: float,
moneyness: ndarray, vc: float) -> float:
"""
:param sc:
:param pc:
:param cc:
:param dc:
:param dsm:
:param uc:
:param usm:
:param moneyness:
:param vc:
:return:
"""
con_arr = []
for x in moneyness:
con_arr.append(cls._check_butterfly_arbitrage(sc, pc, cc, dc, dsm, uc, usm, x, vc))
con_arr = array(con_arr)
if (con_arr >= 0).all():
return minimum(con_arr.mean(), 1e-7)
else:
return maximum((con_arr[con_arr < 0]).mean(), -1e-7)
最简单情形的自动对冲
简介
本自动对冲是与策略下单相互独立,策略下单所需要的下单中保持对冲会集成在策略下单的代码中,本章的自动对冲适用于在未触发策略信号时的额外对冲操作,如果有其他自动交易的策略触发时,应该停止对应的合约的自动对冲操作。
选用近月或次月(可选)离合成期货最近的行权价对应的合成期货对作为对冲合约; 手动填入本次的目标$delta以及可以接受的一个上下范围区间,对冲必须要达到目标$delta一次后才会考虑上下范围; 例子:当前实际$delta为5%,目标$delta为20%,给定的上下容忍范围为10%,那么刚开始对冲的情形下,$delta必须到达一次20%才停止进入等待状态,后续才会判断$delta出了目标$delta20%上下10%去做对冲,也就是低于10%或高于30%.
$vega方面,不同于$delta我们给定的是目标$delta,$vega我们会给一个本次$vega,$vega我们会实时计算从本次对冲开始下的单对应的$vega,该$vega未到设定的本次要做的$vega前我们只用call或者put去做对冲,在该$vega达到本次要做的$vega后我们只会用合成期货去进行对冲;下的单对应的$vega到过一次本次要做的$vega后就再也不考虑$vega这一希腊值了
TODO:追价类型目前是'aT|b|c|M/C'的形式, 后续希望能够有更多的自定义操作,如作为买方时每次以min(对手价+2T, 前一笔+2T)的价格进行追价,待讨论
流程图
GUI代码部分
需要模块 icetcore, loguru
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File : hedger_vanilla.py
@Time : 2023/02/15 20:55:33
@Author : DingWenjie
@Contact : 359582058@qq.com
@Desc : ### 独立自动对冲模块 ###
对冲合约: 平值合成期货
需要输入: 对冲月份: 近月/次月
目标delta: 百分值
目标delta容忍范围: 正百分值
本次对冲vega: 万分值
下单模式: 可选 本方价+1/对手价-1/中间价
相邻单间隔: int,秒
追价模式:
每笔单数: int
### 对冲逻辑 ###
以目标delta为主, 首次对冲(first_time)必须delta达到目标delta,
非首次时delta必须达到目标delta的容忍范围内,
vega则是以本次对冲为准, 不看总仓位的vega, 未达到本次对冲vega时,
只买(卖)不卖(买), 到达本次对冲vega后, 使用合成期货对冲
'''
import warnings
import datetime
import os
from loguru import logger
from tkinter import ttk, Tk, Label, Entry, StringVar, Button
from copy import deepcopy
from icetcore import TCoreAPI, QuoteEvent, TradeEvent, OrderStruct
from functools import partial
from threading import Thread
from time import sleep
now = datetime.datetime.now()
TODAY_STR = datetime.date.today().strftime('%Y%m%d')
if not os.path.exists('log'):
os.mkdir('log')
warnings.simplefilter('ignore')
logger.add(f"log/runtime_{now.strftime('%Y%m%d_%H_%M_%S')}.log")
type_list = ['本方+1', '中间价', '对手-1']
interval_list = ['1秒', '2秒', '3秒', '4秒', '5秒', '6秒',
'7秒', '8秒', '9秒', '10秒', '15秒', '20秒', '30秒']
qty_list = [str(i+1) for i in range(30)]
chase_list = ['1T|1|2|M', '1T|2|2|M', '2T|1|2|M', '2T|2|2|M', '1T|2|10|M']
class APIEvent(TradeEvent, QuoteEvent):
def __init__(self):
super().__init__()
def onconnected(self, apitype: str):
pass
def ondisconnected(self, apitype: str):
pass
def ongreeksreal(self, datatype, symbol, data):
global greeks_dict
try:
greeks_dict[data['Symbol']] = data
except:
return
def onbar(self, datatype, interval, symbol, data, isreal):
pass
def onfilledreportreal(self, data):
global filled_report_dict
temp_key = '.'.join(data['Symbol'].split('.')[
2:4])
if 'A' in temp_key:
temp_key = temp_key[:-1]
filled_report_dict[temp_key][data['Symbol'].split(
'.')[4]][data['DetailReportID']] = data
def onmargin(self, accmask, data):
global cash
global initial_cash
cash = data[0]['MarketPremium']
if not cash:
try:
cash = initial_cash
except NameError:
return
def onATM(self, datatype, symbol, data):
global atm_dict
try:
if data['ATM'] == '.'.join(data['OTM-1C'].split('.')[6:]):
atm_dict['.'.join(data['Symbol'].split('.')[2:5])] = {
'call': data['OTM-1C'],
'put': '.'.join(data['OTM-1C'].split('.')[:5])+'.P.'+data['ATM']
}
else:
atm_dict['.'.join(data['Symbol'].split('.')[2:5])] = {
'put': data['OTM-1P'],
'call': '.'.join(data['OTM-1P'].split('.')[:5])+'.C.'+data['ATM']
}
except:
return
def onpositionmoniter(self, data):
global und_list
global position_dict
try:
und_list
except:
return
try:
position_dict[und_list[0]]
except:
for und in und_list:
position_dict[und] = {
'Total': {'delta': 0, 'vega': 0},
month_list[0]: {'delta': 0, 'vega': 0},
month_list[1]: {'delta': 0, 'vega': 0}
}
for und in und_list:
for position in data:
if und in position['Symbol']:
position_dict[und][position['SubKey']
]['delta'] = position['$Delta']
position_dict[und][position['SubKey']
]['vega'] = position['$Vega']
class HedgerVanilla():
def __init__(self, account, brokerid):
self.account = account
self.brokerid = brokerid
logger.warning(f'当前登录的账户为{self.account},请确认后再继续!')
global cash
global initial_cash
cash = 0
while not cash:
cash = api.getaccmargin(
self.brokerid+'-'+self.account)['MarketPremium']
initial_cash = cash
@staticmethod
def _get_0and1_symbols_and_quote():
'''根据und_list订阅该标的对应的近月次月合约greeks及atm
'''
global und_list
global greeks_dict
global atm_dict
global symbol_dict
global month_list
atm_dict = {}
greeks_dict = {}
symbol_dict = {}
for und in und_list:
all_symbol = api.getallsymbol('OPT', und.split('.')[0])
symbol_dict[und] = all_symbol
month_list = list(set([int(symbol.split('.')[4])
for symbol in all_symbol]))
month_list.sort()
month_list = [str(month) for month in month_list[:2]]
for month in month_list:
api.subATM('TC.O.'+und+'.'+month+'.GET.ATM')
for symbol in all_symbol:
if und in symbol and (month_list[0] in symbol or month_list[1] in symbol):
api.subgreeksreal(symbol)
@staticmethod
def _calculate_cashvega_given_und(und, month):
'''根据已成交的单计算给定标的本次对冲(not vega_done)已做的cashvega
'''
global filled_report_dict
vega = 0
temp_symbol = {}
for order, data in filled_report_dict[und][month].items():
if data['Symbol'] not in temp_symbol:
temp_symbol[data['Symbol']] = int(
(data['Side']*(-2)+3)*data['MatchedQty'])
else:
temp_symbol[data['Symbol']
] += int((data['Side']*(-2)+3)*data['MatchedQty'])
for symbol, data in temp_symbol.items():
try:
vega += data*greeks_dict[symbol]['Vega']*10000
except TypeError:
print(f'data:{data}')
print('vega:', greeks_dict[symbol]['Vega']*10000)
print('\n'*50)
continue
return vega/cash*10000
@staticmethod
def _get_order_given_side_and_type(order, side, type_):
'''根据买卖方向和价格类型给出order_obj
'''
global order_obj
order_obj = deepcopy(order)
order_obj.Side = side
if side == 1:
if type_ == '本方+1':
order_obj.Price = 'BID+1T'
elif type_ == '对手-1':
order_obj.Price = 'ASK-1T'
else:
order_obj.OrderType = 15
order_obj.Synthetic = 1
else:
if type_ == '本方+1':
order_obj.Price = 'ASK-1T'
elif type_ == '对手-1':
order_obj.Price = 'BID+1T'
else:
order_obj.OrderType = 15
order_obj.Synthetic = 1
return order_obj
def _do_hedging_thread(self, und, month):
'''进行对冲子线程
'''
global hedging_state_dict
global filled_report_dict
global done_vega_dict
global atm_dict
global order_type_dict
global order_interval_dict
global order_qty_dict
global chase_type_dict
done_vega_dict[und][month] = 0
logger.debug(f'当前开始{und}_{month}的对冲子线程')
if globals()['delta_entry'].get() == '':
globals()[f'warning_tag_{month}']['text'] = '请输入目标delta!'
logger.debug(f'{und}_{month}对冲子线程因未输入终止')
hedging_state_dict[und][month] = 0
return
target_delta = float(globals()['delta_entry'].get())
if globals()['delta_tol'].get() == '':
globals()[f'warning_tag_{month}']['text'] = '请输入delta范围!'
logger.debug(f'{und}_{month}对冲子线程因未输入终止')
hedging_state_dict[und][month] = 0
return
delta_tol = float(globals()['delta_tol'].get())
if globals()[f'vega_entry_{month}'].get() == '':
globals()[f'warning_tag_{month}']['text'] = '请输入本次vega!'
logger.debug(f'{und}_{month}对冲子线程因未输入终止')
hedging_state_dict[und][month] = 0
return
globals()[f'warning_tag_{month}']['text'] = ''
target_vega = float(globals()[f'vega_entry_{month}'].get())
vega_done = False
first_time = True
delta_direction_for_first_time = 0
filled_report_dict[und][month] = {}
while hedging_state_dict[und][month]:
try:
temp_order_type = order_type_dict[und][month]
temp_order_interval = int(
order_interval_dict[und][month].split('秒')[0])
temp_order_qty = int(order_qty_dict[und][month])
temp_chase_type = chase_type_dict[und][month]
logger.debug(f'当前在{und}_{month}的对冲子线程中')
logger.debug(
f'目标delta为{target_delta}, 容忍范围是{delta_tol}, 要做的vega是{target_vega}, 月份为{month}, 下单间隔为{temp_order_interval}, 每笔手数为{temp_order_qty}, 下单类型为{temp_order_type}')
temp_delta = position_dict[und]['Total']['delta']
logger.debug(
f'当前{month}月份temp_delta为{temp_delta:.0f}, 当前目标delta为{target_delta/100*cash}')
hedger_symbol_call = atm_dict[und+'.'+month]['call']
hedger_symbol_put = atm_dict[und+'.'+month]['put']
if not vega_done: # 之前未做到过vega, 则先判断当前vega是否已到
done_vega_dict[und][month] = self._calculate_cashvega_given_und(
und, month)
logger.info(
f'当前已做的vega为 {done_vega_dict[und][month]: .1f}')
if abs(done_vega_dict[und][month]) >= abs(target_vega):
logger.success('当前首次到达目标vega, 后续做合成期货!')
vega_done = True
if first_time: # 若未到达过目标delta, 则判断所做delta方向(首次)/当前是否到达目标delta
if not delta_direction_for_first_time:
if temp_delta > target_delta/100*cash >= 0 or (temp_delta > target_delta/100*cash and 0 >= target_delta):
delta_direction_for_first_time = -1
else:
delta_direction_for_first_time = 1
logger.debug(
f'初次对冲, 判断对冲方向为{delta_direction_for_first_time}')
elif (temp_delta-target_delta/100*cash)*delta_direction_for_first_time > 0:
logger.info(
f'当前first_time,当前delta与目标delta之差为{temp_delta-target_delta/100*cash}')
first_time = False
continue
order_obj = OrderStruct(Account=self.account,
BrokerID=self.brokerid,
OrderQty=temp_order_qty,
OrderType=2, # 默认限价单
Symbol='',
Side=1,
TimeInForce=1,
PositionEffect=4,
SelfTradePrevention=3,
ChasePrice=temp_chase_type)
if (first_time and temp_delta < target_delta/100*cash) or (not first_time and temp_delta < (target_delta-delta_tol)/100*cash): # 做正delta
logger.info('判断本次做正delta')
if not vega_done: # 买call or 卖put
if target_vega > 0: # 买call
order = self._get_order_given_side_and_type(
order=order_obj, side=1, type_=temp_order_type)
order.Symbol = hedger_symbol_call
else: # 卖put
order = self._get_order_given_side_and_type(
order=order_obj, side=2, type_=temp_order_type)
order.Symbol = hedger_symbol_put
order_response, _ = api.neworder(order)
else: # 买合成期货
order_call = self._get_order_given_side_and_type(
order=order_obj, side=1, type_=temp_order_type)
order_call.Symbol = hedger_symbol_call
order_put = self._get_order_given_side_and_type(
order=order_obj, side=2, type_=temp_order_type)
order_put.Symbol = hedger_symbol_put
order_response_call, _ = api.neworder(order_call)
order_response_put, _ = api.neworder(order_put)
elif (first_time and temp_delta > target_delta/100*cash) or (not first_time and temp_delta > (target_delta+delta_tol)/100*cash): # 做负delta
logger.info('判断本次做负delta')
if not vega_done: # 卖call or 买put
if target_vega < 0: # 卖call
order = self._get_order_given_side_and_type(
order=order_obj, side=2, type_=temp_order_type)
order.Symbol = hedger_symbol_call
else: # 买put
order = self._get_order_given_side_and_type(
order=order_obj, side=1, type_=temp_order_type)
order.Symbol = hedger_symbol_put
order_response, _ = api.neworder(order)
else: # 卖合成期货
order_call = self._get_order_given_side_and_type(
order=order_obj, side=2, type_=temp_order_type)
order_call.Symbol = hedger_symbol_call
order_put = self._get_order_given_side_and_type(
order=order_obj, side=1, type_=temp_order_type)
order_put.Symbol = hedger_symbol_put
order_response_call, _ = api.neworder(order_call)
order_response_put, _ = api.neworder(order_put)
else:
logger.debug('当前delta满足要求, 不做对冲')
sleep(temp_order_interval)
except Exception as error:
logger.exception(error)
hedging_state_dict[und][month] = 0
logger.success(f'{und}对冲子线程已结束!')
filled_report_dict[und][month] = {}
done_vega_dict[und][month] = '-'
def _do_hedging(self, month):
'''tk.Button绑定开始对冲按键
'''
global hedging_state_dict
if hedging_state_dict[csd_und][month]:
logger.warning(f'{csd_und}_{month}已在对冲中')
return
global target_delta_dict
global target_delta_tol_dict
global target_vega_dict
global order_type_dict
global order_interval_dict
global order_qty_dict
global chase_type_dict
target_delta_dict[csd_und] = globals()['delta_entry'].get()
target_delta_tol_dict[csd_und] = globals()['delta_tol'].get()
target_vega_dict[csd_und][month] = globals()[
f'vega_entry_{month}'].get()
order_type_dict[csd_und][month] = globals(
)[f'csd_order_type_{month}'].get()
order_interval_dict[csd_und][month] = globals(
)[f'csd_order_interval_{month}'].get()
order_qty_dict[csd_und][month] = globals(
)[f'csd_order_qty_{month}'].get()
chase_type_dict[csd_und][month] = globals(
)[f'csd_chase_type_{month}'].get()
logger.debug(f'开始对冲{csd_und}-{month}')
hedging_state_dict[csd_und][month] = 1
globals()[f'hedging_subprocess_{csd_und}_{month}'] = Thread(
target=self._do_hedging_thread, args=(csd_und, month), daemon=True)
globals()[f'hedging_subprocess_{csd_und}_{month}'].start()
def _stop_hedging(self, month):
'''tk.Button绑定停止对冲按键
'''
global hedging_state_dict
if not hedging_state_dict[csd_und][month]:
logger.warning(f'{csd_und}_{month}未在对冲')
return
global target_delta_dict
global target_delta_tol_dict
global target_vega_dict
target_delta_dict[csd_und] = ''
target_delta_tol_dict[csd_und] = ''
target_vega_dict[csd_und][month] = ''
hedging_state_dict[csd_und][month] = 0
logger.debug(f'停止对冲{csd_und}_{month}')
def main_window(self):
'''GUI窗口
'''
def go(*args):
pass
global position_dict
global und_list
global month_list
global hedging_state_dict
global hedging_state_list
global done_vega_dict
global filled_report_dict
global order_report_dict
global csd_und
global target_delta_dict
global target_vega_dict
global target_delta_tol_dict
global order_type_dict
global order_interval_dict
global order_qty_dict
global chase_type_dict
chase_type_dict = {}
order_qty_dict = {}
order_interval_dict = {}
order_type_dict = {}
target_delta_tol_dict = {}
target_vega_dict = {}
target_delta_dict = {}
csd_und = und_list[0]
order_report_dict = {}
filled_report_dict = {}
done_vega_dict = {}
hedging_state_list = ['未运行', '对冲中']
hedging_state_dict = {}
position_dict = {}
self._get_0and1_symbols_and_quote()
sleep(3)
window = Tk()
window.title('ETF期权自动对冲初版')
account_tag = Label(text='当前账户:')
account_tag.grid(row=0, column=0)
account_tag_ = Label(text=self.account)
account_tag_.grid(row=0, column=1)
cash_tag = Label(text='当前资金:')
cash_tag.grid(row=1, column=0)
cash_tag_ = Label(text=f'{cash:.2f}')
cash_tag_.grid(row=1, column=1)
for i, und in enumerate(und_list):
order_report_dict[und] = {}
done_vega_dict[und] = {
month_list[0]: '',
month_list[1]: ''
}
hedging_state_dict[und] = {
month_list[0]: 0,
month_list[1]: 0
}
filled_report_dict[und] = {
month_list[0]: {},
month_list[1]: {}
}
target_delta_dict[und] = ''
target_delta_tol_dict[und] = ''
target_vega_dict[und] = {
month_list[0]: '',
month_list[1]: ''
}
order_type_dict[und] = {
month_list[0]: 1,
month_list[1]: 1
}
order_interval_dict[und] = {
month_list[0]: 1,
month_list[1]: 1,
}
order_qty_dict[und] = {
month_list[0]: 1,
month_list[1]: 1,
}
chase_type_dict[und] = {
month_list[0]: 0,
month_list[1]: 0,
}
crt_row = 2
text_tag = Label(text='当前总体$delta(百分之):')
text_tag.grid(row=crt_row, column=0)
try:
globals()['crt_delta_total'] = Label(
text=str(round(position_dict[csd_und]['Total']['delta']/cash*100, 1)), width=9)
except KeyError:
globals()['crt_delta_total'] = Label(
text=str('-'), width=8)
globals()['crt_delta_total'].grid(row=crt_row, column=1)
text_tag = Label(text='', width=8)
text_tag.grid(row=crt_row, column=2)
crt_row += 1
crt_vega = Label(text='当前总体$vega(万分之):')
crt_vega.grid(row=crt_row, column=0)
try:
globals()['crt_vega_total'] = Label(
text=str(round(position_dict[csd_und]['Total']['vega']/cash*10000, 1)), width=9)
except KeyError:
globals()['crt_vega_total'] = Label(
text=str('-'), width=8)
globals()['crt_vega_total'].grid(row=crt_row, column=1)
crt_row += 1
def choose_und(*args):
global csd_und
global target_delta_dict
global target_delta_tol_dict
global target_vega_dict
global order_type_dict
global order_interval_dict
global order_qty_dict
global chase_type_dict
csd_und = globals()['choose_csd_und'].get()
globals()['delta_entry'].delete(0, 5)
globals()['delta_entry'].insert(0, str(target_delta_dict[csd_und]))
globals()['delta_tol'].delete(0, 5)
globals()['delta_tol'].insert(
0, str(target_delta_tol_dict[csd_und]))
for month in month_list:
globals()[f'vega_entry_{month}'].delete(0, 5)
globals()[f'vega_entry_{month}'].insert(
0, str(target_vega_dict[csd_und][month]))
if hedging_state_dict[csd_und][month]:
globals()[f'csd_order_type_{month}'].current(
type_list.index(order_type_dict[csd_und][month]))
globals()[f'csd_order_interval_{month}'].current(
interval_list.index(order_interval_dict[csd_und][month]))
globals()[f'csd_order_qty_{month}'].current(
qty_list.index(order_qty_dict[csd_und][month]))
globals()[f'csd_chase_type_{month}'].current(
chase_list.index(chase_type_dict[csd_und][month]))
text_tag = Label(text='选择对冲合约:')
text_tag.grid(row=crt_row, column=0)
globals()['choose_csd_und'] = ttk.Combobox(
window, textvariable=StringVar(), width=11)
globals()['choose_csd_und']['values'] = und_list
globals()['choose_csd_und'].current(0)
globals()['choose_csd_und'].bind('<<ComboboxSelected>>', choose_und)
globals()['choose_csd_und'].grid(row=crt_row, column=1)
crt_row += 1
text_tag = Label(text='')
text_tag.grid(row=crt_row, column=0)
crt_row += 1
text_tag = Label(text='对冲设置')
text_tag.grid(row=crt_row, column=0)
fixed_row_num = crt_row
col_num = 4
for i, month in enumerate(month_list):
crt_row = fixed_row_num
text_tag = Label(text='考虑月份')
text_tag.grid(row=crt_row, column=i*col_num)
text_tag = Label(text=month, fg='red')
text_tag.grid(row=crt_row, column=i*col_num+1)
crt_row += 1
text_tag = Label(text='该月$delta(百分之):')
text_tag.grid(row=crt_row, column=i*col_num)
try:
globals()[f'crt_delta_{month}'] = Label(
text=str(round(position_dict[csd_und][month]['delta']/cash*100, 1)), width=9)
except KeyError:
globals()[f'crt_delta_{month}'] = Label(
text=str('-'), width=8)
globals()[f'crt_delta_{month}'].grid(
row=crt_row, column=1+i*col_num)
text_tag = Label(text='', width=8)
text_tag.grid(row=crt_row, column=2+i*col_num)
crt_row += 1
crt_vega = Label(text='该月$vega(万分之):')
crt_vega.grid(row=crt_row, column=i*col_num)
try:
globals()[f'crt_vega_{month}'] = Label(
text=str(round(position_dict[csd_und][month]['vega']/cash*10000, 1)), width=9)
except KeyError:
globals()[f'crt_vega_{month}'] = Label(
text=str('-'), width=8)
globals()[f'crt_vega_{month}'].grid(
row=crt_row, column=1+i*col_num)
crt_row += 1
globals()[f'warning_tag_{month}'] = Label(text='', fg='red')
globals()[f'warning_tag_{month}'].grid(
row=crt_row, column=1+i*col_num)
crt_row += 1
if not i:
text_tag = Label(text='目标总delta(百分之):')
text_tag.grid(row=crt_row, column=i*col_num)
globals()['delta_entry'] = Entry(window, width=10)
globals()['delta_entry'].grid(
row=crt_row, column=1+i*col_num)
crt_row += 1
if not i:
text_tag = Label(text='delta范围(百分之):')
text_tag.grid(row=crt_row, column=i*col_num)
globals()['delta_tol'] = Entry(window, width=10)
globals()['delta_tol'].grid(
row=crt_row, column=1+i*col_num)
crt_row += 1
text_tag = Label(text='本次vega(万分之):')
text_tag.grid(row=crt_row, column=i*col_num)
globals()[f'vega_entry_{month}'] = Entry(window, width=10)
globals()[f'vega_entry_{month}'].grid(
row=crt_row, column=1+i*col_num)
crt_row += 1
text_tag = Label(text='选择对冲下单模式:')
text_tag.grid(row=crt_row, column=i*col_num)
globals()[f'csd_order_type_{month}'] = ttk.Combobox(
window, textvariable=StringVar(), width=7)
globals()[f'csd_order_type_{month}']['values'] = type_list
globals()[f'csd_order_type_{month}'].current(0)
globals()[f'csd_order_type_{month}'].bind(
'<<ComboboxSelected>>', go)
globals()[f'csd_order_type_{month}'].grid(
row=crt_row, column=1+i*col_num)
crt_row += 1
text_tag = Label(text='选择下单相邻间隔:')
text_tag.grid(row=crt_row, column=i*col_num)
globals()[f'csd_order_interval_{month}'] = ttk.Combobox(
window, textvariable=StringVar(), width=7)
globals()[f'csd_order_interval_{month}']['values'] = interval_list
globals()[f'csd_order_interval_{month}'].current(1)
globals()[f'csd_order_interval_{month}'].bind(
'<<ComboboxSelected>>', go)
globals()[f'csd_order_interval_{month}'].grid(
row=crt_row, column=1+i*col_num)
crt_row += 1
text_tag = Label(text='对冲每笔下单单数:')
text_tag.grid(row=crt_row, column=i*col_num)
globals()[f'csd_order_qty_{month}'] = ttk.Combobox(
window, textvariable=StringVar(), width=7)
globals()[f'csd_order_qty_{month}']['values'] = qty_list
globals()[f'csd_order_qty_{month}'].current(0)
globals()[f'csd_order_qty_{month}'].bind(
'<<ComboboxSelected>>', go)
globals()[f'csd_order_qty_{month}'].grid(
row=crt_row, column=1+i*col_num)
crt_row += 1
text_tag = Label(text='选择对冲追价方式:')
text_tag.grid(row=crt_row, column=i*col_num)
globals()[f'csd_chase_type_{month}'] = ttk.Combobox(
window, textvariable=StringVar(), width=7)
globals()[f'csd_chase_type_{month}']['values'] = chase_list
globals()[f'csd_chase_type_{month}'].current(3)
globals()[f'csd_chase_type_{month}'].bind(
'<<ComboboxSelected>>', go)
globals()[f'csd_chase_type_{month}'].grid(
row=crt_row, column=1+i*col_num)
crt_row += 1
text_tag = Label(text='')
text_tag.grid(row=crt_row, column=i*col_num)
crt_row += 1
text_tag = Label(text='当前对冲状态:')
text_tag.grid(row=crt_row, column=i*col_num)
globals()[f'hedging_state_{month}'] = Label(
text=hedging_state_list[hedging_state_dict[und][month]], fg='red')
globals()[f'hedging_state_{month}'].grid(
row=crt_row, column=1+i*col_num)
crt_row += 1
text_tag = Label(text='当前已做vega:')
text_tag.grid(row=crt_row, column=i*col_num)
globals()[f'done_vega_{month}'] = Label(text='0')
globals()[f'done_vega_{month}'].grid(
row=crt_row, column=1+i*col_num)
crt_row += 1
Button(window, text="开始对冲", width=10, command=partial(
self._do_hedging, month)).grid(row=crt_row, column=i*col_num, padx=1, pady=1)
crt_row += 1
Button(window, text="停止对冲", width=10, command=partial(
self._stop_hedging, month)).grid(row=crt_row, column=i*col_num, padx=1, pady=1)
crt_row += 1
text_tag = Label(text='')
text_tag.grid(row=crt_row, column=0)
crt_row += 1
text_tag = Label(text='注意事项:', fg='red')
text_tag.grid(row=crt_row, column=0)
crt_row += 1
text_tag = Label(
text='目标delta, delta范围, 本次vega三个输出栏不可在对冲中进行更改!(待更改)', fg='red')
text_tag.grid(row=crt_row, column=0, columnspan=6)
crt_row += 1
text_tag = Label(
text='追价方式参数为"递增tick|追几次|每次几秒|M"', fg='red')
text_tag.grid(row=crt_row, column=0, columnspan=5)
def update():
while True:
sleep(0.5)
cash_tag_['text'] = f'{cash:.2f}'
try:
globals()['crt_delta_total']['text'] = str(
round(position_dict[csd_und]['Total']['delta']/cash*100, 1))
except KeyError:
globals()['crt_delta_total']['text'] = str('-')
try:
globals()['crt_vega_total']['text'] = str(
round(position_dict[csd_und]['Total']['vega']/cash*10000, 1))
except KeyError:
globals()['crt_vega_total']['text'] = str('-')
for month in month_list:
try:
globals()[f'crt_delta_{month}']['text'] = str(
round(position_dict[csd_und][month]['delta']/cash*100, 1))
except KeyError:
globals()[f'crt_delta_{month}']['text'] = str('-')
try:
globals()[f'crt_vega_{month}']['text'] = str(
round(position_dict[csd_und][month]['vega']/cash*10000, 1))
except KeyError:
globals()[f'crt_vega_{month}']['text'] = str('-')
globals()[
f'hedging_state_{month}']['text'] = hedging_state_list[hedging_state_dict[csd_und][month]]
try:
globals()[f'done_vega_{month}']['text'] = str(
round(done_vega_dict[csd_und][month], 1))
except TypeError:
globals()[f'done_vega_{month}']['text'] = '-'
if hedging_state_dict[csd_und][month]:
order_type_dict[csd_und][month] = globals(
)[f'csd_order_type_{month}'].get()
order_interval_dict[csd_und][month] = globals(
)[f'csd_order_interval_{month}'].get()
order_qty_dict[csd_und][month] = globals(
)[f'csd_order_qty_{month}'].get()
chase_type_dict[csd_und][month] = globals(
)[f'csd_chase_type_{month}'].get()
def monitor(window):
window.after(100, update())
update_thread = Thread(
target=update, daemon=True, name='update_thread')
monitor_thread = Thread(
target=monitor, args=(window), daemon=True, name='monitor_thread')
update_thread.start()
monitor_thread.start()
window.mainloop()
if __name__ == '__main__':
api = TCoreAPI(APIEvent)
re = api.connect()
account_info = api.getaccountlist()
account = account_info[0]['Account']
brokerid = account_info[0]['BrokerID']
global und_list
und_list = ['SSE.510050', 'SSE.510300', 'SSE.510500', 'SZSE.159915']
hedger = HedgerVanilla(account=account, brokerid=brokerid)
hedger.main_window()
打散:例如对二分类问题来说,m个样本最多有2^m个可能结果,每种可能结果称为一种**“对分”**,若假设空间能实现数据集D的所有对分,则称数据集能被该假设空间打散。